def test45_diff_loop(m):
    """Forward-mode differentiation through a Monte Carlo loop.

    A custom operation estimates the integral of 1 / sqrt(1 - m sin^2(x))
    over [0, pi/2] (the 'K' function) with a loop-based Monte Carlo
    integrator. Its forward-mode derivative is obtained from a second Monte
    Carlo integral over the integrand's derivative with respect to 'm'.
    Both the value and the derivative at m = 0.5 are compared against
    reference constants.
    """

    def mcint(a, b, f, sample_count=100000):
        # Monte Carlo estimate of the integral of 'f' over [a, b]
        rng = m.PCG32()
        counter = m.UInt32(0)
        total = m.Float(0)

        loop = m.Loop(counter, rng, total)
        while loop.cond(counter < sample_count):
            sample = ek.lerp(a, b, rng.next_float32())
            total += f(sample)
            counter += 1

        return total * (b - a) / sample_count

    class EllipticK(ek.CustomOp):
        # --- Helper routines used by the CustomOp callbacks below ---

        # Integrand of the 'K' function
        def K(self, x, m_):
            sin_x = ek.sin(x)
            return ek.rsqrt(1 - m_ * ek.sqr(sin_x))

        # Derivative of the integrand with respect to 'm', computed by
        # running a forward-mode pass through K()
        def dK(self, x, m_):
            m_diff = m.Float(m_)  # Convert 'm' to a differentiable type
            ek.enable_grad(m_diff)
            integrand_value = self.K(x, m_diff)
            ek.forward(m_diff)
            return ek.grad(integrand_value)

        # Monte Carlo integral of dK, used by the derivative pass
        def eval_grad(self):
            def integrand(x):
                return self.dK(x, self.m_)

            return mcint(a=0, b=ek.Pi / 2, f=integrand)

        # --- CustomOp interface ---

        def eval(self, m_):
            self.m_ = m_  # Keep 'm' around for eval_grad()

            def integrand(x):
                return self.K(x, self.m_)

            return mcint(a=0, b=ek.Pi / 2, f=integrand)

        def forward(self):
            grad_m = self.grad_in('m_')
            self.set_grad_out(grad_m * self.eval_grad())

    def elliptic_k(m_):
        return ek.custom(EllipticK, m_)

    record_loops = ek.JitFlag.RecordLoops
    ek.enable_flag(record_loops)

    param = m.Float(0.5)
    ek.enable_grad(param)
    value = elliptic_k(param)
    ek.forward(param)

    assert ek.allclose(value, 1.85407, rtol=5e-4)
    assert ek.allclose(ek.grad(value), 0.847213, rtol=5e-4)
    ek.disable_flag(record_loops)
def test05_side_effect_noloop(pkg):
    """Run a loop containing a scatter_add with loop recording disabled.

    Ten lanes each count from 0 to 10. On every iteration the running sum
    'total' accumulates the counter, and the already-incremented counter is
    added into entry 0 of 'buf' under the loop's mask.
    """
    p = get_class(pkg)
    width = 10

    counter = ek.zero(p.Int, width)
    total = ek.zero(p.Int, width)
    buf = ek.zero(p.Float, width)
    ek.disable_flag(ek.JitFlag.RecordLoops)

    loop = p.Loop(counter, total)
    while loop.cond(counter < 10):
        total += counter
        counter += 1
        ek.scatter_add(
            target=buf,
            value=p.Float(counter),
            index=0,
            mask=loop.mask(),
        )

    # Expected: each lane contributes 1 + 2 + ... + 10 = 55 to buf[0]
    # (550 over ten lanes), and 0 + 1 + ... + 9 = 45 to its own 'total'.
    assert counter == p.Int([10] * width)
    assert buf == p.Float(550, *([0] * (width - 1)))
    assert total == p.Int([45] * width)
def test46_loop_ballistic_2(m):
    """Gradient descent through a looped simulation without stored states.

    The custom operation advances positions/velocities for 100 time steps.
    Its backward pass does not keep intermediate states: starting from the
    final state it steps with a negated 'dt' to get the previous state, then
    repeats the forward step with derivative tracking to propagate the
    gradients one step back. The test runs 20 descent iterations on the
    initial velocity so that the final position approaches (5, 0).
    """

    class Ballistic2(ek.CustomOp):
        def timestep(self, pos, vel, dt=0.02, mu=.1, g=9.81):
            # Acceleration: drag term scaled by 'mu' minus gravity (0, g)
            drag = -mu * vel * ek.norm(vel)
            accel = drag - m.Array2f(0, g)
            # Advance position and velocity by one step of size 'dt'
            next_pos = pos + dt * vel
            next_vel = vel + dt * accel
            return next_pos, next_vel

        def eval(self, pos, vel):
            pos = m.Array2f(pos)
            vel = m.Array2f(vel)

            # Advance the simulation by 100 time steps
            step = m.UInt32(0)
            step_count = 100

            loop = m.Loop(pos, vel, step)
            while loop.cond(step < step_count):
                # Write the new state back into the loop variables
                next_pos, next_vel = self.timestep(pos, vel)
                pos.assign(next_pos)
                vel.assign(next_vel)
                step += 1

            # backward() starts its reverse integration from the final state
            self.pos, self.vel = pos, vel

            return pos, vel

        def backward(self):
            grad_pos, grad_vel = self.grad_out()
            pos = self.pos
            vel = self.vel

            # Undo the 100 time steps one by one
            step = m.UInt32(0)

            loop = m.Loop(step, pos, vel, grad_pos, grad_vel)
            while loop.cond(step < 100):
                # Step backwards in time to obtain the previous state
                prev_pos, prev_vel = self.timestep(pos, vel, dt=-0.02)
                pos.assign(prev_pos)
                vel.assign(prev_vel)

                # Repeat the forward step from that state while tracking
                # derivatives, then push the gradients through it
                ek.enable_grad(prev_pos, prev_vel)
                redo_pos, redo_vel = self.timestep(prev_pos, prev_vel, dt=0.02)
                ek.set_grad(redo_pos, grad_pos)
                ek.set_grad(redo_vel, grad_vel)
                ek.enqueue(redo_pos, redo_vel)
                ek.traverse(m.Float, reverse=True)

                grad_pos.assign(ek.grad(prev_pos))
                grad_vel.assign(ek.grad(prev_vel))
                step += 1

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)

    record_loops = ek.JitFlag.RecordLoops
    ek.enable_flag(record_loops)

    pos_in = m.Array2f([1, 2, 4], [1, 2, 1])
    vel_in = m.Array2f([10, 9, 4], [5, 3, 6])

    for _ in range(20):
        ek.enable_grad(vel_in)
        ek.eval(vel_in, pos_in)
        pos_out, vel_out = ek.custom(Ballistic2, pos_in, vel_in)
        target = m.Array2f(5, 0)
        loss = ek.squared_norm(pos_out - target)
        ek.backward(loss)

        # Gradient descent update with step size 0.2
        current = ek.detach(vel_in)
        gradient = ek.grad(vel_in)
        vel_in = m.Array2f(current - 0.2 * gradient)

    assert ek.allclose(loss, 0, atol=1e-4)
    assert ek.allclose(vel_in.x, [3.3516, 2.3789, 0.79156], atol=1e-3)
    ek.disable_flag(record_loops)
def test46_loop_ballistic(m):
    """Gradient descent through a looped simulation with stored states.

    The custom operation advances positions/velocities for 100 time steps
    and scatters the state at the start of every step into scratch arrays.
    Its backward pass walks the steps in reverse order, gathers the stored
    state, and differentiates a single time step in reverse mode. The test
    runs 20 descent iterations on the initial velocity so that the final
    position approaches (5, 0).
    """

    class Ballistic(ek.CustomOp):
        def timestep(self, pos, vel, dt=0.02, mu=.1, g=9.81):
            # Acceleration: drag term scaled by 'mu' minus gravity (0, g)
            drag = -mu * vel * ek.norm(vel)
            accel = drag - m.Array2f(0, g)
            # Advance position and velocity by one step of size 'dt'
            next_pos = pos + dt * vel
            next_vel = vel + dt * accel
            return next_pos, next_vel

        def eval(self, pos, vel):
            pos = m.Array2f(pos)
            vel = m.Array2f(vel)

            # Advance the simulation by 100 time steps
            step = m.UInt32(0)
            step_count = 100

            # Scratch arrays with one slot per lane and time step
            lanes = max(ek.width(pos), ek.width(vel))
            self.temp_pos = ek.empty(m.Array2f, lanes * step_count)
            self.temp_vel = ek.empty(m.Array2f, lanes * step_count)

            loop = m.Loop(pos, vel, step)
            while loop.cond(step < step_count):
                # Record the state at the beginning of this step
                slot = step * lanes + ek.arange(m.UInt32, lanes)
                ek.scatter(self.temp_pos, pos, slot)
                ek.scatter(self.temp_vel, vel, slot)

                # Write the new state back into the loop variables
                next_pos, next_vel = self.timestep(pos, vel)
                pos.assign(next_pos)
                vel.assign(next_vel)
                step += 1

            # Make sure outputs and scratch arrays are evaluated here
            ek.eval(pos, vel)

            return pos, vel

        def backward(self):
            grad_pos, grad_vel = self.grad_out()

            # Count down over the 100 recorded time steps
            step = m.UInt32(100)

            loop = m.Loop(step, grad_pos, grad_vel)
            lanes = ek.width(grad_pos)
            while loop.cond(step > 0):
                # Fetch the stored state, newest step first
                step -= 1
                slot = step * lanes + ek.arange(m.UInt32, lanes)
                pos = ek.gather(m.Array2f, self.temp_pos, slot)
                vel = ek.gather(m.Array2f, self.temp_vel, slot)

                # Reverse-mode differentiation of a single time step
                ek.enable_grad(pos, vel)
                next_pos, next_vel = self.timestep(pos, vel)
                ek.set_grad(next_pos, grad_pos)
                ek.set_grad(next_vel, grad_vel)
                ek.enqueue(next_pos, next_vel)
                ek.traverse(m.Float, reverse=True)

                # Carry the gradients over to the preceding step
                grad_pos.assign(ek.grad(pos))
                grad_vel.assign(ek.grad(vel))

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)

    pos_in = m.Array2f([1, 2, 4], [1, 2, 1])
    vel_in = m.Array2f([10, 9, 4], [5, 3, 6])

    record_loops = ek.JitFlag.RecordLoops
    ek.enable_flag(record_loops)

    for _ in range(20):
        ek.enable_grad(vel_in)
        ek.eval(vel_in, pos_in)
        pos_out, vel_out = ek.custom(Ballistic, pos_in, vel_in)
        target = m.Array2f(5, 0)
        loss = ek.squared_norm(pos_out - target)
        ek.backward(loss)

        # Gradient descent update with step size 0.2
        current = ek.detach(vel_in)
        gradient = ek.grad(vel_in)
        vel_in = m.Array2f(current - 0.2 * gradient)

    assert ek.allclose(loss, 0, atol=1e-4)
    assert ek.allclose(vel_in.x, [3.3516, 2.3789, 0.79156], atol=1e-3)
    ek.disable_flag(record_loops)
def teardown_function(function):
    """Turn the RecordLoops JIT flag off.

    The name follows pytest's per-test teardown hook convention, so this is
    presumably invoked after each test function in the module; 'function'
    is accepted but not used.
    """
    record_loops = ek.JitFlag.RecordLoops
    ek.disable_flag(record_loops)