def rcp_(a0):
    """Compute the reciprocal of the floating point array ``a0``.

    Plain (non-special) arrays are processed component-wise through
    ``_ek.rcp``. Complex and quaternion arrays use the identity
    ``q^-1 = conj(q) / |q|^2``.

    Raises:
        Exception: if ``a0`` is not floating point, or is a special array
            type other than complex/quaternion.
    """
    if not a0.IsFloat:
        raise Exception("rcp(): requires floating point operands!")

    # Output array plus the number of components to fill
    result, size = _check1(a0)

    if a0.IsSpecial:
        if not (a0.IsComplex or a0.IsQuaternion):
            raise Exception('rcp(): unsupported array type!')
        # Inverse of a complex number / quaternion: conjugate over squared norm
        return _ek.conj(a0) * _ek.rcp(_ek.squared_norm(a0))

    for idx in range(size):
        result[idx] = _ek.rcp(a0[idx])
    return result
def log2_(a0):
    """Compute the base-2 logarithm of the floating point array ``a0``.

    Plain (non-special) arrays are processed component-wise through
    ``_ek.log2``. For complex arrays the result is
    ``log2(z) = 0.5 * log2(|z|^2) + i * arg(z) * InvLogTwo``.

    Raises:
        Exception: if ``a0`` is not floating point, or is a special array
            type other than complex.
    """
    if not a0.IsFloat:
        raise Exception("log2(): requires floating point operands!")

    # Output array plus the number of components to fill
    out, count = _check1(a0)

    if a0.IsSpecial:
        if not a0.IsComplex:
            raise Exception("log2(): unsupported array type!")
        # 0.5 * log2(|z|^2) == log2(|z|), computed without a square root.
        # InvLogTwo presumably equals 1 / ln(2) (converts arg(z) to base 2).
        out.real = .5 * _ek.log2(_ek.squared_norm(a0))
        out.imag = _ek.arg(a0) * _ek.InvLogTwo
        return out

    for idx in range(count):
        out[idx] = _ek.log2(a0[idx])
    return out
def log_(a0):
    """Compute the natural logarithm of the floating point array ``a0``.

    Plain (non-special) arrays are processed component-wise through
    ``_ek.log``. Complex arrays use ``log(z) = 0.5 * log(|z|^2) + i * arg(z)``.
    Quaternions use
    ``log(q) = log(|q|) + normalize(imag(q)) * acos(real(q) / |q|)``.

    Raises:
        Exception: if ``a0`` is not floating point, or is a special array
            type other than complex/quaternion.
    """
    if not a0.IsFloat:
        raise Exception("log(): requires floating point operands!")

    # Output array plus the number of components to fill
    out, count = _check1(a0)

    if not a0.IsSpecial:
        for idx in range(count):
            out[idx] = _ek.log(a0[idx])
        return out

    if a0.IsComplex:
        out.real = .5 * _ek.log(_ek.squared_norm(a0))
        out.imag = _ek.arg(a0)
        return out

    if a0.IsQuaternion:
        axis = _ek.normalize(a0.imag)           # unit direction of vector part
        magnitude = _ek.norm(a0)                # |q|
        angle = _ek.acos(a0.real / magnitude)   # angle between q and the real axis
        log_mag = _ek.log(magnitude)
        out.imag = axis * angle
        out.real = log_mag
        return out

    raise Exception("log(): unsupported array type!")
def test46_loop_ballistic_2(m):
    """Fit initial velocities of a drag-affected trajectory through a recorded loop.

    The custom op's backward pass does not store the forward history.
    Instead, each iteration takes a reverse time step from the final state and
    then differentiates a single forward step from there. The global
    ``RecordLoops`` JIT flag is always restored, even when the test fails.
    """
    class Ballistic2(ek.CustomOp):
        def timestep(self, pos, vel, dt=0.02, mu=.1, g=9.81):
            # Explicit Euler step: quadratic drag (-mu * v * |v|) plus gravity along -y
            acc = -mu * vel * ek.norm(vel) - m.Array2f(0, g)
            pos_out = pos + dt * vel
            vel_out = vel + dt * acc
            return pos_out, vel_out

        def eval(self, pos, vel):
            pos, vel = m.Array2f(pos), m.Array2f(vel)

            # Run for 100 iterations
            it, max_it = m.UInt32(0), 100

            loop = m.Loop(pos, vel, it)
            while loop.cond(it < max_it):
                # Update loop variables
                pos_out, vel_out = self.timestep(pos, vel)
                pos.assign(pos_out)
                vel.assign(vel_out)
                it += 1

            # Final state; backward() starts its reverse integration from here
            self.pos = pos
            self.vel = vel

            return pos, vel

        def backward(self):
            grad_pos, grad_vel = self.grad_out()
            # NOTE(review): these alias the arrays returned by eval(), and the
            # assign() calls below overwrite them in place. Confirm that
            # ek.custom() does not hand these same objects to the caller.
            pos, vel = self.pos, self.vel

            # Run for 100 iterations
            it = m.UInt32(0)

            loop = m.Loop(it, pos, vel, grad_pos, grad_vel)
            while loop.cond(it < 100):
                # Take reverse step in time. This only approximately undoes a
                # forward step, since explicit Euler is not exactly reversible.
                pos_rev, vel_rev = self.timestep(pos, vel, dt=-0.02)
                pos.assign(pos_rev)
                vel.assign(vel_rev)

                # Take a forward step in time, keep track of derivatives
                ek.enable_grad(pos_rev, vel_rev)
                pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
                ek.set_grad(pos_fwd, grad_pos)
                ek.set_grad(vel_fwd, grad_vel)
                ek.enqueue(pos_fwd, vel_fwd)
                ek.traverse(m.Float, reverse=True)

                grad_pos.assign(ek.grad(pos_rev))
                grad_vel.assign(ek.grad(vel_rev))
                it += 1

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)

    ek.enable_flag(ek.JitFlag.RecordLoops)
    try:
        pos_in = m.Array2f([1, 2, 4], [1, 2, 1])
        vel_in = m.Array2f([10, 9, 4], [5, 3, 6])

        # Gradient descent on the initial velocity so the trajectory ends at (5, 0)
        for _ in range(20):
            ek.enable_grad(vel_in)
            ek.eval(vel_in, pos_in)
            pos_out, vel_out = ek.custom(Ballistic2, pos_in, vel_in)
            loss = ek.squared_norm(pos_out - m.Array2f(5, 0))
            ek.backward(loss)

            vel_in = m.Array2f(ek.detach(vel_in) - 0.2 * ek.grad(vel_in))

        assert ek.allclose(loss, 0, atol=1e-4)
        assert ek.allclose(vel_in.x, [3.3516, 2.3789, 0.79156], atol=1e-3)
    finally:
        # Always restore the global JIT flag. Otherwise a failure here would
        # leave RecordLoops enabled for every subsequent test.
        ek.disable_flag(ek.JitFlag.RecordLoops)
def test46_loop_ballistic(m):
    """Fit initial velocities of a drag-affected trajectory through a recorded loop.

    The custom op's forward pass scatters every intermediate state into
    scratch buffers. The backward pass gathers those states in reverse
    chronological order and differentiates one timestep per iteration. The
    global ``RecordLoops`` JIT flag is always restored, even when the test
    fails.
    """
    class Ballistic(ek.CustomOp):
        def timestep(self, pos, vel, dt=0.02, mu=.1, g=9.81):
            # Explicit Euler step: quadratic drag (-mu * v * |v|) plus gravity along -y
            acc = -mu * vel * ek.norm(vel) - m.Array2f(0, g)
            pos_out = pos + dt * vel
            vel_out = vel + dt * acc
            return pos_out, vel_out

        def eval(self, pos, vel):
            pos, vel = m.Array2f(pos), m.Array2f(vel)

            # Run for 100 iterations
            it, max_it = m.UInt32(0), 100

            # Allocate scratch space: one slot of width n per iteration,
            # laid out as index = it * n + lane
            n = max(ek.width(pos), ek.width(vel))
            self.temp_pos = ek.empty(m.Array2f, n * max_it)
            self.temp_vel = ek.empty(m.Array2f, n * max_it)

            loop = m.Loop(pos, vel, it)
            while loop.cond(it < max_it):
                # Store current loop variables
                index = it * n + ek.arange(m.UInt32, n)
                ek.scatter(self.temp_pos, pos, index)
                ek.scatter(self.temp_vel, vel, index)

                # Update loop variables
                pos_out, vel_out = self.timestep(pos, vel)
                pos.assign(pos_out)
                vel.assign(vel_out)

                it += 1

            # Ensure output and temp. arrays are evaluated at this point
            ek.eval(pos, vel)

            return pos, vel

        def backward(self):
            grad_pos, grad_vel = self.grad_out()

            # Run for 100 iterations (must match max_it in eval())
            it = m.UInt32(100)

            loop = m.Loop(it, grad_pos, grad_vel)
            # NOTE(review): assumes grad_pos has the same width n that eval()
            # used for the scratch layout -- confirm when grad_out() may be
            # a width-1 (broadcast) zero.
            n = ek.width(grad_pos)
            while loop.cond(it > 0):
                # Retrieve loop variables, reverse chronological order
                it -= 1
                index = it * n + ek.arange(m.UInt32, n)
                pos = ek.gather(m.Array2f, self.temp_pos, index)
                vel = ek.gather(m.Array2f, self.temp_vel, index)

                # Differentiate loop body in reverse mode
                ek.enable_grad(pos, vel)
                pos_out, vel_out = self.timestep(pos, vel)
                ek.set_grad(pos_out, grad_pos)
                ek.set_grad(vel_out, grad_vel)
                ek.enqueue(pos_out, vel_out)
                ek.traverse(m.Float, reverse=True)

                # Update loop variables
                grad_pos.assign(ek.grad(pos))
                grad_vel.assign(ek.grad(vel))

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)

    pos_in = m.Array2f([1, 2, 4], [1, 2, 1])
    vel_in = m.Array2f([10, 9, 4], [5, 3, 6])

    ek.enable_flag(ek.JitFlag.RecordLoops)
    try:
        # Gradient descent on the initial velocity so the trajectory ends at (5, 0)
        for _ in range(20):
            ek.enable_grad(vel_in)
            ek.eval(vel_in, pos_in)
            pos_out, vel_out = ek.custom(Ballistic, pos_in, vel_in)
            loss = ek.squared_norm(pos_out - m.Array2f(5, 0))
            ek.backward(loss)

            vel_in = m.Array2f(ek.detach(vel_in) - 0.2 * ek.grad(vel_in))

        assert ek.allclose(loss, 0, atol=1e-4)
        assert ek.allclose(vel_in.x, [3.3516, 2.3789, 0.79156], atol=1e-3)
    finally:
        # Always restore the global JIT flag. Otherwise a failure here would
        # leave RecordLoops enabled for every subsequent test.
        ek.disable_flag(ek.JitFlag.RecordLoops)