def backward(self): grad_pos, grad_vel = self.grad_out() # Run for 100 iterations it = m.UInt32(100) loop = m.Loop(it, grad_pos, grad_vel) n = ek.width(grad_pos) while loop.cond(it > 0): # Retrieve loop variables, reverse chronological order it -= 1 index = it * n + ek.arange(m.UInt32, n) pos = ek.gather(m.Array2f, self.temp_pos, index) vel = ek.gather(m.Array2f, self.temp_vel, index) # Differentiate loop body in reverse mode ek.enable_grad(pos, vel) pos_out, vel_out = self.timestep(pos, vel) ek.set_grad(pos_out, grad_pos) ek.set_grad(vel_out, grad_vel) ek.enqueue(pos_out, vel_out) ek.traverse(m.Float, reverse=True) # Update loop variables grad_pos.assign(ek.grad(pos)) grad_vel.assign(ek.grad(vel)) self.set_grad_in('pos', grad_pos) self.set_grad_in('vel', grad_vel)
def backward(self):
    """Back-propagate the output gradients through 100 iterations of ``self.timestep``.

    This variant stores no intermediate states. Each iteration:

    - steps the current state backward in time (``dt=-0.02``) to recover the
      previous state;
    - re-runs a forward step (``dt=0.02``) from that state with gradients
      enabled;
    - back-propagates ``grad_pos`` / ``grad_vel`` through that single step.

    This is only exact if ``timestep`` with a negated ``dt`` inverts the
    forward step — TODO confirm for the integrator in use. The final gradients
    are handed to ``self.set_grad_in`` for ``'pos'`` and ``'vel'``.
    """
    grad_pos, grad_vel = self.grad_out()
    # NOTE(review): pos/vel are the same objects as self.pos/self.vel (plain
    # binding), so pos.assign(...)/vel.assign(...) below overwrite
    # self.pos/self.vel in place — confirm this side effect is intended.
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    it = m.UInt32(0)
    # Loop-carried variables; they are updated in place below (.assign, it += 1).
    loop = m.Loop(it, pos, vel, grad_pos, grad_vel)
    while loop.cond(it < 100):
        # Take reverse step in time
        pos_rev, vel_rev = self.timestep(pos, vel, dt=-0.02)
        pos.assign(pos_rev)
        vel.assign(vel_rev)

        # Take a forward step in time, keep track of derivatives
        ek.enable_grad(pos_rev, vel_rev)
        pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
        ek.set_grad(pos_fwd, grad_pos)
        ek.set_grad(vel_fwd, grad_vel)
        ek.enqueue(pos_fwd, vel_fwd)
        ek.traverse(m.Float, reverse=True)

        # Gradients w.r.t. the step inputs become the new loop gradients
        grad_pos.assign(ek.grad(pos_rev))
        grad_vel.assign(ek.grad(vel_rev))
        it += 1

    self.set_grad_in('pos', grad_pos)
    self.set_grad_in('vel', grad_vel)
def backward(self):
    """Back-propagate the output gradients through 100 iterations of ``self.timestep``.

    This is the same time-reversal scheme as the ``.assign``-based variant, but
    it uses the ``m.Loop(name, lambda: ...)`` form. Loop state is therefore
    updated by plain rebinding instead of in-place assignment. Each iteration:

    - steps backward (``dt=-0.02``);
    - re-runs a forward step (``dt=0.02``) on a copy of the recovered state
      with gradients enabled;
    - back-propagates ``grad_pos`` / ``grad_vel`` through it.

    This assumes the negated-``dt`` step inverts the forward step — TODO
    confirm. The final gradients are handed to ``self.set_grad_in``.
    """
    grad_pos, grad_vel = self.grad_out()
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    it = m.UInt32(0)
    # The lambda is a closure: each call reads the *current* bindings of
    # it/pos/vel/grad_pos/grad_vel. That is why the rebinding below
    # (pos, vel = ..., grad_pos = ..., it += 1) needs no .assign(). This
    # presumes m.Loop calls the lambda to fetch the loop state — confirm.
    loop = m.Loop("backward", lambda: (it, pos, vel, grad_pos, grad_vel))
    while loop(it < 100):
        # Take reverse step in time
        pos, vel = self.timestep(pos, vel, dt=-0.02)

        # Take a forward step in time, keep track of derivatives
        # (gradients are enabled on copies rather than on the loop variables
        # pos/vel — presumably to keep the loop state detached; confirm)
        pos_rev, vel_rev = m.Array2f(pos), m.Array2f(vel)
        ek.enable_grad(pos_rev, vel_rev)
        pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
        ek.set_grad(pos_fwd, grad_pos)
        ek.set_grad(vel_fwd, grad_vel)
        ek.enqueue(pos_fwd, vel_fwd)
        ek.traverse(m.Float, reverse=True)

        # Gradients w.r.t. the step inputs become the new loop gradients
        grad_pos = ek.grad(pos_rev)
        grad_vel = ek.grad(vel_rev)
        it += 1

    self.set_grad_in('pos', grad_pos)
    self.set_grad_in('vel', grad_vel)
def test43_custom_reverse(m):
    """Reverse-mode check of ``ek.custom(Normalize, ...)``.

    Seeds the output gradient with (5, 6, 7), propagates it back to the input
    and compares against reference values. The references equal J^T (5, 6, 7),
    where J is the Jacobian of x / |x| at x = (1, 2, 3).
    """
    source = m.Array3f(1, 2, 3)
    ek.enable_grad(source)
    normalized = ek.custom(Normalize, source)

    # Seed the output and run a reverse traversal back to the input.
    ek.set_grad(normalized, m.Array3f(5, 6, 7))
    ek.enqueue(normalized)
    ek.traverse(m.Float, reverse=True)

    expected = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(ek.grad(source), expected)
def test44_custom_forward(m):
    """Forward-mode check of ``ek.custom(Normalize, ...)``.

    Covers three things:

    - a forward traversal clears the seeded input gradient;
    - with ``retain_graph=True`` the graph can be traversed again;
    - the output gradient accumulates across the two traversals.
    """
    source = m.Array3f(1, 2, 3)
    ek.enable_grad(source)
    normalized = ek.custom(Normalize, source)
    reference = m.Array3f(0.610883, 0.152721, -0.305441)

    def push_forward(keep_graph):
        # Queue the seeded input and propagate its tangent to the output.
        ek.enqueue(source)
        ek.traverse(m.Float, reverse=False, retain_graph=keep_graph)

    # First pass: keep the graph alive for a second traversal.
    ek.set_grad(source, m.Array3f(5, 6, 7))
    push_forward(True)
    assert ek.grad(source) == 0

    # Re-seed; the output still holds the first pass's tangent.
    ek.set_grad(source, m.Array3f(5, 6, 7))
    assert ek.allclose(ek.grad(normalized), reference)

    # Second pass accumulates on top of the first.
    push_forward(False)
    assert ek.allclose(ek.grad(normalized), reference * 2)
def test21_scatter_add_fwd(m):
    """Forward-mode check of ``ek.scatter_add`` followed by ``ek.dot_async``.

    Three cases seed unit tangents on different inputs:

    - case 0: the destination buffer and the scattered sources x/y;
    - case 1: x/y only;
    - case 2: the buffer only.

    Each case checks the value and the tangent of the squared norm of the
    scattered result.
    """
    for case in range(3):
        seed_buf = case % 2 == 0
        seed_xy = case // 2 == 0

        idx1 = ek.arange(m.UInt, 5)
        idx2 = ek.arange(m.UInt, 4) + 3
        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if seed_buf:
            ek.enable_grad(buf)
            ek.set_grad(buf, 1)
        if seed_xy:
            ek.enable_grad(x, y)
            ek.set_grad(x, 1)
            ek.set_grad(y, 1)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        # Scatter into a copy of buf; idx1 and idx2 overlap at slots 3 and 4.
        scattered = m.Float(buf)
        ek.scatter_add(scattered, x, idx1)
        ek.scatter_add(scattered, y, idx2)
        sq_norm = ek.dot_async(scattered, scattered)

        if seed_buf:
            ek.enqueue(buf)
        if seed_xy:
            ek.enqueue(x, y)
        ek.traverse(m.Float, reverse=False)

        expected_grad = 0
        if seed_xy:
            expected_grad += 25.1667
        if seed_buf:
            expected_grad += 17

        # Verified against Mathematica
        assert ek.allclose(ek.detach(sq_norm), 15.5972)
        assert ek.allclose(ek.grad(sq_norm), expected_grad)
def test02_add_fwd(m):
    """Forward-mode gradients of ``c = 2 * a + b``.

    The scenario is checked twice: once through ``ek.forward`` and once
    through explicit ``ek.enqueue`` / ``ek.traverse``.
    """

    def check_forward_api():
        # ek.forward seeds the given input and accumulates into c's gradient.
        a, b = m.Float(1), m.Float(2)
        ek.enable_grad(a, b)
        c = 2 * a + b
        ek.forward(a, retain_graph=True)
        assert ek.grad(c) == 2
        ek.set_grad(c, 101)
        ek.forward(b)
        assert ek.grad(c) == 102

    def check_manual_traversal():
        # A forward traversal clears the seeded input gradient, and a second
        # traversal accumulates into c.
        a, b = m.Float(1), m.Float(2)
        ek.enable_grad(a, b)
        c = 2 * a + b
        ek.set_grad(a, 1.0)
        ek.enqueue(a)
        ek.traverse(m.Float, retain_graph=True, reverse=False)
        assert ek.grad(c) == 2
        assert ek.grad(a) == 0
        ek.set_grad(a, 1.0)
        ek.enqueue(a)
        ek.traverse(m.Float, retain_graph=False, reverse=False)
        assert ek.grad(c) == 4

    check_forward_api()
    check_manual_traversal()