def test22_scatter_fwd(m):
    """Forward-mode AD through ``ek.scatter`` into a zero-filled buffer.

    ``m`` is the array-type package under test (it supplies ``m.Float`` and
    ``m.UInt32``). Differentiable values are scattered into every other slot
    of a 10-entry buffer; afterwards one slot is overwritten with a
    non-differentiable value and its gradient entry is expected to vanish.
    """
    x = m.Float(4.0)
    ek.enable_grad(x)

    # Source values x^2 * [1, 2, 3, 4], written to slots 0, 2, 4, 6
    src = x * x * ek.linspace(m.Float, 1, 4, 4)
    slots = 2 * ek.arange(m.UInt32, 4)

    buf = ek.zero(m.Float, 10)
    ek.scatter(buf, src, slots)

    # The target buffer must have picked up gradient tracking
    assert ek.grad_enabled(buf)

    expected = [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(buf, expected)

    # Keep the graph alive, a second forward pass follows below
    ek.forward(x, retain_graph=True)
    expected_grad = [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(ek.grad(buf), expected_grad)

    # Overwrite first value with non-diff value, resulting gradient entry
    # should be 0
    plain = m.Float(3)
    slots = m.UInt32(0)
    ek.scatter(buf, plain, slots)

    expected = [3.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(buf, expected)

    ek.forward(x)
    expected_grad = [0.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(ek.grad(buf), expected_grad)
def backward(self):
    """Propagate output gradients back to the 'pos' and 'vel' inputs.

    Runs a 100-iteration recorded loop. Each iteration steps the state
    backwards in time with ``dt=-0.02``, replays one forward step with
    ``dt=0.02`` under gradient tracking, and pulls the current gradients
    through that step in reverse mode.
    """
    d_pos, d_vel = self.grad_out()
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    step = m.UInt32(0)
    loop = m.Loop(step, pos, vel, d_pos, d_vel)

    while loop.cond(step < 100):
        # Take reverse step in time and store it in the loop state
        pos_prev, vel_prev = self.timestep(pos, vel, dt=-0.02)
        pos.assign(pos_prev)
        vel.assign(vel_prev)

        # Take a forward step in time, keep track of derivatives
        ek.enable_grad(pos_prev, vel_prev)
        pos_next, vel_next = self.timestep(pos_prev, vel_prev, dt=0.02)

        # Seed the step outputs with the running gradients, then traverse
        ek.set_grad(pos_next, d_pos)
        ek.set_grad(vel_next, d_vel)
        ek.enqueue(pos_next, vel_next)
        ek.traverse(m.Float, reverse=True)

        # Gradients w.r.t. the step inputs become the new running gradients
        d_pos.assign(ek.grad(pos_prev))
        d_vel.assign(ek.grad(vel_prev))
        step += 1

    self.set_grad_in('pos', d_pos)
    self.set_grad_in('vel', d_vel)
def test01_add_rev(m):
    """Reverse-mode derivative of ``2 * a + b`` with respect to ``a`` and ``b``."""
    lhs, rhs = m.Float(1), m.Float(2)
    ek.enable_grad(lhs, rhs)

    total = 2 * lhs + rhs
    ek.backward(total)

    # d(2a + b)/da = 2, d(2a + b)/db = 1
    assert ek.grad(lhs) == 2
    assert ek.grad(rhs) == 1
def backward(self):
    """Propagate output gradients back to the 'pos' and 'vel' inputs.

    Variant that hands the loop state to ``m.Loop`` through a lambda and
    rebinds the state variables each iteration instead of assigning in
    place. Every iteration steps backwards in time (``dt=-0.02``), copies
    the state, replays one forward step (``dt=0.02``) on the copies under
    gradient tracking, and traverses that step in reverse mode.
    """
    d_pos, d_vel = self.grad_out()
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    step = m.UInt32(0)
    loop = m.Loop("backward", lambda: (step, pos, vel, d_pos, d_vel))

    while loop(step < 100):
        # Take reverse step in time
        pos, vel = self.timestep(pos, vel, dt=-0.02)

        # Take a forward step in time, keep track of derivatives
        pos_in, vel_in = m.Array2f(pos), m.Array2f(vel)
        ek.enable_grad(pos_in, vel_in)
        pos_out, vel_out = self.timestep(pos_in, vel_in, dt=0.02)

        # Seed the step outputs with the running gradients, then traverse
        ek.set_grad(pos_out, d_pos)
        ek.set_grad(vel_out, d_vel)
        ek.enqueue(pos_out, vel_out)
        ek.traverse(m.Float, reverse=True)

        # Gradients w.r.t. the step inputs become the new running gradients
        d_pos = ek.grad(pos_in)
        d_vel = ek.grad(vel_in)
        step += 1

    self.set_grad_in('pos', d_pos)
    self.set_grad_in('vel', d_vel)
def test04_div(m):
    """Reverse-mode derivative of a quotient ``a / b``."""
    num, den = m.Float(2), m.Float(3)
    ek.enable_grad(num, den)

    quotient = num / den
    ek.backward(quotient)

    # d(a/b)/da = 1/b, d(a/b)/db = -a/b^2
    assert ek.allclose(ek.grad(num), 1.0 / 3.0)
    assert ek.allclose(ek.grad(den), -2.0 / 9.0)
def test06_hsum_0_fwd(m):
    """Forward-mode AD through a horizontal sum of squares."""
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)

    total = ek.hsum_async(x * x)
    ek.forward(x)

    # The reduction yields a single entry holding the sum of squares
    assert len(total) == 1
    assert ek.allclose(ek.detach(total), 95.0 / 27.0)

    # ... and a single forward derivative entry
    assert len(ek.grad(total)) == 1
    assert ek.allclose(ek.grad(total), 10)
def backward(self):
    """Propagate output gradients back to the 'pos' and 'vel' inputs.

    Walks 100 iterations in reverse chronological order. The inputs of each
    iteration are gathered from ``self.temp_pos`` / ``self.temp_vel``, the
    time step is re-evaluated on them under gradient tracking, and the
    running gradients are pulled through it in reverse mode.
    """
    d_pos, d_vel = self.grad_out()

    # Run for 100 iterations
    step = m.UInt32(100)
    loop = m.Loop(step, d_pos, d_vel)
    width = ek.width(d_pos)

    while loop.cond(step > 0):
        # Retrieve loop variables, reverse chronological order
        step -= 1
        offset = step * width + ek.arange(m.UInt32, width)
        pos_in = ek.gather(m.Array2f, self.temp_pos, offset)
        vel_in = ek.gather(m.Array2f, self.temp_vel, offset)

        # Differentiate loop body in reverse mode
        ek.enable_grad(pos_in, vel_in)
        pos_out, vel_out = self.timestep(pos_in, vel_in)
        ek.set_grad(pos_out, d_pos)
        ek.set_grad(vel_out, d_vel)
        ek.enqueue(pos_out, vel_out)
        ek.traverse(m.Float, reverse=True)

        # Update loop variables
        d_pos.assign(ek.grad(pos_in))
        d_vel.assign(ek.grad(vel_in))

    self.set_grad_in('pos', d_pos)
    self.set_grad_in('vel', d_vel)
def test03_sub_mul(m):
    """Reverse-mode derivative of ``a * b - c``."""
    a, b, c = m.Float(2), m.Float(3), m.Float(4)
    ek.enable_grad(a, b, c)

    result = a * b - c
    ek.backward(result)

    # Product rule for a and b, constant -1 for the subtrahend
    assert ek.grad(a) == ek.detach(b)
    assert ek.grad(b) == ek.detach(a)
    assert ek.grad(c) == -1
def test25_pow(m):
    """Reverse-mode derivative of ``x ** y`` w.r.t. both base and exponent."""
    base = ek.linspace(m.Float, 1, 10, 10)
    exponent = ek.full(m.Float, 2.0, 10)
    ek.enable_grad(base, exponent)

    power = base**exponent
    ek.backward(power)

    # With y = 2: d(x^y)/dx = 2 * x
    assert ek.allclose(ek.grad(base), ek.detach(base) * 2)

    # Reference values for the derivative with respect to the exponent
    expected_dy = m.Float(0., 2.77259, 9.88751, 22.1807, 40.2359,
                          64.5033, 95.3496, 133.084, 177.975, 230.259)
    assert ek.allclose(ek.grad(exponent), expected_dy)
def test41_replace_grad(m):
    """``ek.replace_grad`` keeps one operand's value and the other's gradient.

    The result carries the values of ``x * x`` while backpropagation is
    expected to reach only ``y``.
    """
    x = m.Array3f(1, 2, 3)
    y = m.Array3f(3, 2, 1)
    ek.enable_grad(x, y)

    x_sq = x * x
    y_sq = y * y

    mixed = ek.replace_grad(x_sq, y_sq)
    mixed_sq = mixed * mixed
    ek.backward(mixed_sq)

    # Values follow x: (x^2)^2
    assert ek.allclose(mixed_sq, [1, 16, 81])
    # No gradient reaches x, all of it flows to y
    assert ek.grad(x) == 0
    assert ek.allclose(ek.grad(y), [12, 32, 36])
def test44_custom_forward(m):
    """Forward-mode traversal through the custom ``Normalize`` operation.

    Two forward passes are run with the same input gradient. The first one
    retains the graph; the expected output gradient after the second pass
    is twice the one observed after the first.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    normalized = ek.custom(Normalize, vec)

    # First pass, graph is kept for the second one
    ek.set_grad(vec, m.Array3f(5, 6, 7))
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=True)

    # The input gradient reads as zero after the traversal
    assert ek.grad(vec) == 0

    ek.set_grad(vec, m.Array3f(5, 6, 7))
    assert ek.allclose(ek.grad(normalized),
                       m.Array3f(0.610883, 0.152721, -0.305441))

    # Second pass, graph may be released now
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=False)
    assert ek.allclose(ek.grad(normalized),
                       m.Array3f(0.610883, 0.152721, -0.305441) * 2)
def test19_gather_fwd(m):
    """Forward-mode AD through ``ek.gather`` with a repeated index."""
    x = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(x)

    # Entries 1, 1, 2, 3 of x^2 (index 1 is fetched twice)
    picked = ek.gather(m.Float, x * x, m.UInt(1, 1, 2, 3))
    ek.forward(x)

    expected = [-1.55556, -1.55556, -1.11111, -0.666667]
    assert ek.allclose(ek.grad(picked), expected)
def test15_abs(m):
    """Reverse-mode derivative of ``ek.abs`` on a negative and a positive input."""
    x = m.Float(-2, 2)
    ek.enable_grad(x)

    magnitude = ek.abs(x)
    ek.backward(magnitude)

    assert ek.allclose(ek.detach(magnitude), [2, 2])
    # Derivative is the sign of the input
    assert ek.allclose(ek.grad(x), [-1, 1])
def test14_rsqrt(m):
    """Reverse-mode derivative of the reciprocal square root."""
    x = m.Float(1, .25, 0.0625)
    ek.enable_grad(x)

    inv_root = ek.rsqrt(x)
    ek.backward(inv_root)

    assert ek.allclose(ek.detach(inv_root), [1, 2, 4])
    # d(x^-1/2)/dx = -1/2 * x^-3/2
    assert ek.allclose(ek.grad(x), [-.5, -4, -32])
def test13_sqrt(m):
    """Reverse-mode derivative of the square root."""
    x = m.Float(1, 4, 16)
    ek.enable_grad(x)

    root = ek.sqrt(x)
    ek.backward(root)

    assert ek.allclose(ek.detach(root), [1, 2, 4])
    # d(sqrt(x))/dx = 1 / (2 * sqrt(x))
    assert ek.allclose(ek.grad(x), [.5, .25, .125])
def test12_hmax_fwd(m):
    """Forward-mode AD through a horizontal maximum with a tied maximum."""
    x = m.Float(1, 2, 8, 5, 8)
    ek.enable_grad(x)

    peak = ek.hmax_async(x)
    ek.forward(x)

    assert len(peak) == 1
    assert ek.allclose(peak[0], 8)
    assert ek.allclose(ek.grad(peak), [2])  # Approximation
def test11_hmax_rev(m):
    """Reverse-mode AD through a horizontal maximum with a tied maximum."""
    x = m.Float(1, 2, 8, 5, 8)
    ek.enable_grad(x)

    peak = ek.hmax_async(x)
    ek.backward(peak)

    assert len(peak) == 1
    assert ek.allclose(peak[0], 8)
    # Both entries equal to the maximum receive a gradient of 1
    assert ek.allclose(ek.grad(x), [0, 0, 1, 0, 1])
def test11_hprod(m):
    """Reverse-mode AD through a horizontal product."""
    x = m.Float(1, 2, 5, 8)
    ek.enable_grad(x)

    product = ek.hprod_async(x)
    ek.backward(product)

    assert len(product) == 1
    assert ek.allclose(product[0], 80)
    # Each gradient entry equals the product of all the other entries
    assert ek.allclose(ek.grad(x), [80, 40, 16, 10])
def test17_cos(m):
    """Reverse-mode derivative of the cosine."""
    x = ek.linspace(m.Float, 0.01, 10, 10)
    ek.enable_grad(x)

    cosine = ek.cos(x)
    ek.backward(cosine)

    # Primal value matches the non-differentiable evaluation
    assert ek.allclose(ek.detach(cosine), ek.cos(ek.detach(x)))
    # d(cos(x))/dx = -sin(x)
    assert ek.allclose(ek.grad(x), -ek.sin(ek.detach(x)))
def test49_eager_fwd(m):
    """Inside ``EagerMode``, a gradient set on the input shows up on the result.

    No explicit forward/traverse call is made; the gradient of ``x ** 3`` is
    read straight from the output.
    """
    with EagerMode():
        x = m.Float(1)
        ek.enable_grad(x)
        ek.set_grad(x, 10)

        cube = x * x * x

        # 3 * x^2 * 10 evaluated at x = 1
        assert ek.grad(cube) == 30
def test51_scatter_reduce_fwd_eager(m):
    """Gradients through ``ek.scatter_reduce`` inside ``EagerMode``.

    Three configurations are exercised: the buffer has gradients enabled on
    even iterations, the scattered values on iterations 0 and 1. The
    expected gradient is the sum of the contributions of whichever inputs
    are differentiable in the current iteration.
    """
    with EagerMode():
        for i in range(3):
            buf_is_diff = i % 2 == 0
            values_are_diff = i // 2 == 0

            idx1 = ek.arange(m.UInt, 5)
            idx2 = ek.arange(m.UInt, 4) + 3

            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if buf_is_diff:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)

            if values_are_diff:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label = "x"
            y.label = "y"
            buf.label = "buf"

            # Accumulate both value arrays into a copy of the buffer; the
            # two index ranges overlap at slots 3 and 4
            target = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, target, x, idx1)
            ek.scatter_reduce(ek.ReduceOp.Add, target, y, idx2)

            s = ek.dot_async(target, target)

            # Verified against Mathematica
            assert ek.allclose(ek.detach(s), 15.5972)

            expected_grad = ((25.1667 if values_are_diff else 0) +
                             (17 if buf_is_diff else 0))
            assert ek.allclose(ek.grad(s), expected_grad)
def test05_hsum_0_rev(m):
    """Reverse-mode AD through a horizontal sum of squares."""
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)

    total = ek.hsum_async(x * x)
    ek.backward(total)

    assert len(total) == 1
    assert ek.allclose(total, 95.0 / 27.0)
    # d(sum(x_i^2))/dx_i = 2 * x_i
    assert ek.allclose(ek.grad(x), 2 * ek.detach(x))
def test23_exp(m):
    """Reverse-mode derivative of ``exp(x^2)``."""
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)

    result = ek.exp(x * x)
    ek.backward(result)

    # Reference evaluated on the detached input
    reference = ek.exp(ek.sqr(ek.detach(x)))
    assert ek.allclose(result, reference)
    # Chain rule: d(exp(x^2))/dx = 2 * x * exp(x^2)
    assert ek.allclose(ek.grad(x), 2 * ek.detach(x) * reference)
def test09_hsum_2_rev(m):
    """Reverse-mode AD through two nested horizontal sums."""
    x = ek.linspace(m.Float, 0, 1, 11)
    ek.enable_grad(x)

    # The inner reduction result is multiplied back onto x^2 before the
    # outer reduction
    inner = ek.hsum_async(x * x)
    outer = ek.hsum_async(inner * x * x)
    ek.backward(outer)

    expected = [0., 1.54, 3.08, 4.62, 6.16, 7.7,
                9.24, 10.78, 12.32, 13.86, 15.4]
    assert ek.allclose(ek.grad(x), expected)
def test18_gather(m):
    """Reverse-mode AD through ``ek.gather`` with a repeated index."""
    x = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(x)

    picked = ek.gather(m.Float, x * x, m.UInt(1, 1, 2, 3))
    total = ek.hsum_async(picked)
    ek.backward(total)

    # Index 1 is gathered twice, so its gradient is counted twice; entries
    # that were never gathered stay at zero
    expected = [0, -1.55556 * 2, -1.11111, -0.666667, 0, 0, 0, 0, 0, 0]
    assert ek.allclose(ek.grad(x), expected)
def test50_gather_fwd_eager(m):
    """Gradients through ``ek.gather`` inside ``EagerMode``.

    The input gradient is set up front and the output gradient is read
    without an explicit forward/traverse call.
    """
    with EagerMode():
        x = ek.linspace(m.Float, -1, 1, 10)
        ek.enable_grad(x)
        ek.set_grad(x, 1)

        picked = ek.gather(m.Float, x * x, m.UInt(1, 1, 2, 3))

        expected = [-1.55556, -1.55556, -1.11111, -0.666667]
        assert ek.allclose(ek.grad(picked), expected)
def test43_custom_reverse(m):
    """Reverse-mode traversal through the custom ``Normalize`` operation."""
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    normalized = ek.custom(Normalize, vec)

    # Seed the output gradient and propagate it back to the input
    ek.set_grad(normalized, m.Array3f(5, 6, 7))
    ek.enqueue(normalized)
    ek.traverse(m.Float, reverse=True)

    expected = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(ek.grad(vec), expected)
def test24_log(m):
    """Reverse-mode derivative of ``log(x^2)``."""
    x = ek.linspace(m.Float, 0.01, 1, 10)
    ek.enable_grad(x)

    result = ek.log(x * x)
    ek.backward(result)

    # Reference evaluated on the detached input
    reference = ek.log(ek.sqr(ek.detach(x)))
    assert ek.allclose(result, reference)
    # Chain rule: d(log(x^2))/dx = 2 / x
    assert ek.allclose(ek.grad(x), 2 / ek.detach(x))
def test40_safe_functions(m):
    """Gradients of ``safe_sqrt`` / ``safe_acos`` / ``safe_asin`` stay finite.

    The inputs include the domain boundaries (0 for the square root, -1 and
    1 for the inverse trigonometric functions).
    """
    x = ek.linspace(m.Float, 0, 1, 10)
    y = ek.linspace(m.Float, -1, 1, 10)
    z = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(x, y, z)

    root = ek.safe_sqrt(x)
    arc_cos = ek.safe_acos(y)
    arc_sin = ek.safe_asin(z)

    ek.backward(root)
    ek.backward(arc_cos)
    ek.backward(arc_sin)

    # Gradient is zero at the boundary and the usual 1 / (2 * sqrt(x))
    # at the next sample, x = 1/9
    assert ek.grad(x)[0] == 0
    assert ek.allclose(ek.grad(x)[1], .5 / ek.sqrt(1 / 9))
    assert x[0] == 0

    for source in (x, y, z):
        assert ek.all(ek.isfinite(ek.grad(source)))
def test42_suspend_resume(m):
    """Suspending and resuming gradient tracking toggles the reported state.

    After resuming, backpropagation through a product works as usual. The
    variables are suspended once more at the very end and left that way.
    """
    x = m.Array3f(1, 2, 3)
    y = m.Array3f(3, 2, 1)
    ek.enable_grad(x, y)

    # Freshly enabled: tracking is on, nothing is suspended
    assert ek.grad_enabled(x) and ek.grad_enabled(y)
    assert not ek.grad_suspended(x) and not ek.grad_suspended(y)

    ek.suspend_grad(x, y)
    assert not ek.grad_enabled(x) and not ek.grad_enabled(y)
    assert ek.grad_suspended(x) and ek.grad_suspended(y)

    # Product computed while suspended; the result is intentionally unused
    b = x * y

    ek.resume_grad(x, y)
    assert ek.grad_enabled(x) and ek.grad_enabled(y)
    assert not ek.grad_suspended(x) and not ek.grad_suspended(y)

    product = x * y
    ek.backward(product)

    assert ek.grad(x) == ek.detach(y)
    assert ek.grad(y) == ek.detach(x)

    # validate reference counting of suspended variables
    ek.suspend_grad(x, y)