Esempio n. 1
0
def test45_diff_loop(m):
    """Forward-differentiate a Monte Carlo estimate of the elliptic integral K(m).

    A custom operation estimates
        K(m) = integral over [0, pi/2] of 1 / sqrt(1 - m * sin(x)^2) dx
    inside an enoki ``Loop`` (with loop recording enabled). It supplies its
    forward-mode derivative by integrating the pointwise derivative of the
    integrand with respect to ``m`` using the same estimator.
    """
    def mcint(a, b, f, sample_count=100000):
        # Uniform Monte Carlo estimate of the integral of f over [a, b]; the
        # sample loop is expressed as an enoki Loop over (i, rng, result).
        rng = m.PCG32()
        i = m.UInt32(0)
        result = m.Float(0)
        l = m.Loop(i, rng, result)
        while l.cond(i < sample_count):
            result += f(ek.lerp(a, b, rng.next_float32()))
            i += 1
        return result * (b - a) / sample_count

    class EllipticK(ek.CustomOp):
        # --- Internally used utility methods ---

        # Integrand of the 'K' function
        def K(self, x, m_):
            return ek.rsqrt(1 - m_ * ek.sqr(ek.sin(x)))

        # Derivative of the above with respect to 'm', obtained by a local
        # forward-mode AD pass through K().
        def dK(self, x, m_):
            m_ = m.Float(m_)  # Convert 'm' to differentiable type
            ek.enable_grad(m_)
            y = self.K(x, m_)
            ek.forward(m_)
            return ek.grad(y)

        # Monte Carlo integral of dK, used in the forward pass
        def eval_grad(self):
            return mcint(a=0, b=ek.Pi / 2, f=lambda x: self.dK(x, self.m_))

        # --- CustomOp interface ---

        def eval(self, m_):
            self.m_ = m_  # Stash 'm' for the derivative computation
            return mcint(a=0, b=ek.Pi / 2, f=lambda x: self.K(x, self.m_))

        def forward(self):
            # Chain rule: d(out) = d(m) * dK/dm.
            self.set_grad_out(self.grad_in('m_') * self.eval_grad())

    def elliptic_k(m_):
        return ek.custom(EllipticK, m_)

    ek.enable_flag(ek.JitFlag.RecordLoops)
    try:
        x = m.Float(0.5)
        ek.enable_grad(x)
        y = elliptic_k(x)
        ek.forward(x)
        assert ek.allclose(y, 1.85407, rtol=5e-4)
        assert ek.allclose(ek.grad(y), 0.847213, rtol=5e-4)
    finally:
        # Restore the global flag even when an assertion fails, so later
        # tests do not silently run with loop recording enabled.
        ek.disable_flag(ek.JitFlag.RecordLoops)
Esempio n. 2
0
def test05_side_effect_noloop(pkg):
    """Check a masked scatter_add side effect inside a loop with recording disabled.

    Every one of the 10 lanes iterates 10 times. In each iteration every active
    lane adds its current counter value into slot 0 of ``buf``. Slot 0 should
    therefore end up holding 10 * (1 + 2 + ... + 10) = 550.
    """
    p = get_class(pkg)
    lanes = 10

    counter = ek.zero(p.Int, lanes)
    partial_sum = ek.zero(p.Int, lanes)
    buf = ek.zero(p.Float, lanes)
    ek.disable_flag(ek.JitFlag.RecordLoops)

    loop = p.Loop(counter, partial_sum)
    while loop.cond(counter < 10):
        partial_sum += counter
        counter += 1
        # Only lanes still active in the loop contribute to the side effect.
        ek.scatter_add(target=buf, value=p.Float(counter), index=0,
                       mask=loop.mask())

    assert counter == p.Int([10] * lanes)
    assert buf == p.Float(550, *([0] * 9))
    # 0 + 1 + ... + 9 == 45 per lane.
    assert partial_sum == p.Int([45] * lanes)
Esempio n. 3
0
def test46_loop_ballistic_2(m):
    """Optimize launch velocities through a custom op with a time-reversed adjoint.

    The forward pass simulates 2D ballistic motion for a fixed number of steps
    and keeps only the final state. The backward pass stores no intermediate
    states. It integrates the trajectory backwards in time from the final
    state and differentiates a single forward step at each reconstructed state.
    """
    class Ballistic2(ek.CustomOp):
        def timestep(self, pos, vel, dt=0.02, mu=.1, g=9.81):
            # Explicit Euler step under quadratic drag (coefficient mu) and
            # gravity g acting along -y.
            acc = -mu * vel * ek.norm(vel) - m.Array2f(0, g)
            pos_out = pos + dt * vel
            vel_out = vel + dt * acc
            return pos_out, vel_out

        def eval(self, pos, vel):
            # Copy the inputs, since the loop updates them in place via assign().
            pos, vel = m.Array2f(pos), m.Array2f(vel)

            # Fixed iteration count, shared with backward().
            it = m.UInt32(0)
            self.max_it = 100

            loop = m.Loop(pos, vel, it)
            while loop.cond(it < self.max_it):
                # Update loop variables
                pos_out, vel_out = self.timestep(pos, vel)
                pos.assign(pos_out)
                vel.assign(vel_out)

                it += 1

            # The final state is the starting point of the backward pass.
            self.pos = pos
            self.vel = vel

            return pos, vel

        def backward(self):
            grad_pos, grad_vel = self.grad_out()
            # Work on copies: the assign() calls below would otherwise
            # overwrite the very arrays that eval() returned as outputs.
            pos, vel = m.Array2f(self.pos), m.Array2f(self.vel)

            it = m.UInt32(0)

            loop = m.Loop(it, pos, vel, grad_pos, grad_vel)
            while loop.cond(it < self.max_it):
                # Take reverse step in time. Euler with -dt only approximately
                # inverts a forward step, so the reconstruction is inexact.
                pos_rev, vel_rev = self.timestep(pos, vel, dt=-0.02)
                pos.assign(pos_rev)
                vel.assign(vel_rev)

                # Take a forward step in time, keep track of derivatives, and
                # propagate the current adjoint back through that one step.
                ek.enable_grad(pos_rev, vel_rev)
                pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
                ek.set_grad(pos_fwd, grad_pos)
                ek.set_grad(vel_fwd, grad_vel)
                ek.enqueue(pos_fwd, vel_fwd)
                ek.traverse(m.Float, reverse=True)

                grad_pos.assign(ek.grad(pos_rev))
                grad_vel.assign(ek.grad(vel_rev))
                it += 1

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)

    ek.enable_flag(ek.JitFlag.RecordLoops)
    try:
        pos_in = m.Array2f([1, 2, 4], [1, 2, 1])
        vel_in = m.Array2f([10, 9, 4], [5, 3, 6])

        # Plain gradient descent on the launch velocity so that the final
        # position approaches the target (5, 0).
        for _ in range(20):
            ek.enable_grad(vel_in)
            ek.eval(vel_in, pos_in)
            pos_out, vel_out = ek.custom(Ballistic2, pos_in, vel_in)
            loss = ek.squared_norm(pos_out - m.Array2f(5, 0))
            ek.backward(loss)

            vel_in = m.Array2f(ek.detach(vel_in) - 0.2 * ek.grad(vel_in))

        assert ek.allclose(loss, 0, atol=1e-4)
        assert ek.allclose(vel_in.x, [3.3516, 2.3789, 0.79156], atol=1e-3)
    finally:
        # Restore the global flag even if the optimization or an assertion fails.
        ek.disable_flag(ek.JitFlag.RecordLoops)
Esempio n. 4
0
def test46_loop_ballistic(m):
    """Optimize launch velocities through a custom op that stores every loop state.

    The forward pass scatters the (pos, vel) state before each step into
    scratch arrays. The backward pass gathers these states in reverse
    chronological order and differentiates one time step at a time. It
    carries the adjoint through the loop variables.
    """
    class Ballistic(ek.CustomOp):
        def timestep(self, pos, vel, dt=0.02, mu=.1, g=9.81):
            # Explicit Euler step under quadratic drag (coefficient mu) and
            # gravity g acting along -y.
            acc = -mu * vel * ek.norm(vel) - m.Array2f(0, g)
            pos_out = pos + dt * vel
            vel_out = vel + dt * acc
            return pos_out, vel_out

        def eval(self, pos, vel):
            # Copy the inputs, since the loop updates them in place via assign().
            pos, vel = m.Array2f(pos), m.Array2f(vel)

            # Run for a fixed number of iterations; backward() replays them.
            it = m.UInt32(0)
            self.max_it = max_it = 100

            # Allocate scratch space: the state before step k lives at
            # indices [k * n, (k + 1) * n).
            n = max(ek.width(pos), ek.width(vel))
            self.n = n
            self.temp_pos = ek.empty(m.Array2f, n * max_it)
            self.temp_vel = ek.empty(m.Array2f, n * max_it)

            loop = m.Loop(pos, vel, it)
            while loop.cond(it < max_it):
                # Store current loop variables
                index = it * n + ek.arange(m.UInt32, n)
                ek.scatter(self.temp_pos, pos, index)
                ek.scatter(self.temp_vel, vel, index)

                # Update loop variables
                pos_out, vel_out = self.timestep(pos, vel)
                pos.assign(pos_out)
                vel.assign(vel_out)

                it += 1

            # Ensure output and temp. arrays are evaluated at this point
            ek.eval(pos, vel)

            return pos, vel

        def backward(self):
            grad_pos, grad_vel = self.grad_out()

            # Reuse eval()'s lane and step counts so that the gather indices
            # match the scatter layout, even if an incoming gradient is
            # narrower than the simulated state.
            n = self.n
            it = m.UInt32(self.max_it)

            loop = m.Loop(it, grad_pos, grad_vel)
            while loop.cond(it > 0):
                # Retrieve loop variables, reverse chronological order
                it -= 1
                index = it * n + ek.arange(m.UInt32, n)
                pos = ek.gather(m.Array2f, self.temp_pos, index)
                vel = ek.gather(m.Array2f, self.temp_vel, index)

                # Differentiate loop body in reverse mode
                ek.enable_grad(pos, vel)
                pos_out, vel_out = self.timestep(pos, vel)
                ek.set_grad(pos_out, grad_pos)
                ek.set_grad(vel_out, grad_vel)
                ek.enqueue(pos_out, vel_out)
                ek.traverse(m.Float, reverse=True)

                # Update loop variables
                grad_pos.assign(ek.grad(pos))
                grad_vel.assign(ek.grad(vel))

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)

    pos_in = m.Array2f([1, 2, 4], [1, 2, 1])
    vel_in = m.Array2f([10, 9, 4], [5, 3, 6])

    ek.enable_flag(ek.JitFlag.RecordLoops)
    try:
        # Plain gradient descent on the launch velocity so that the final
        # position approaches the target (5, 0).
        for _ in range(20):
            ek.enable_grad(vel_in)
            ek.eval(vel_in, pos_in)
            pos_out, vel_out = ek.custom(Ballistic, pos_in, vel_in)
            loss = ek.squared_norm(pos_out - m.Array2f(5, 0))
            ek.backward(loss)

            vel_in = m.Array2f(ek.detach(vel_in) - 0.2 * ek.grad(vel_in))

        assert ek.allclose(loss, 0, atol=1e-4)
        assert ek.allclose(vel_in.x, [3.3516, 2.3789, 0.79156], atol=1e-3)
    finally:
        # Restore the global flag even if the optimization or an assertion fails.
        ek.disable_flag(ek.JitFlag.RecordLoops)
Esempio n. 5
0
def teardown_function(function):
    """pytest per-test teardown hook: switch loop recording back off."""
    record_loops = ek.JitFlag.RecordLoops
    ek.disable_flag(record_loops)