Ejemplo n.º 1
0
        def backward(self):
            """Back-propagate the output gradients through 100 stored steps.

            The counter starts at 100 and is decremented at the top of every
            pass, so steps are visited in the order 99, 98, ..., 0. For each
            step the state is gathered from ``self.temp_pos`` and
            ``self.temp_vel`` (presumably filled in by the forward pass --
            TODO confirm against the forward evaluation routine), a single
            ``self.timestep`` call is differentiated in reverse mode, and the
            gradients w.r.t. the gathered state become the output gradients
            of the next pass. The final values are handed to
            ``self.set_grad_in`` for ``'pos'`` and ``'vel'``.
            """
            d_pos, d_vel = self.grad_out()

            # Iteration counter: counts down from 100 to 0
            step = m.UInt32(100)

            rec = m.Loop(step, d_pos, d_vel)
            width = ek.width(d_pos)
            while rec.cond(step > 0):
                # Fetch the state belonging to this step (newest first)
                step -= 1
                offsets = step * width + ek.arange(m.UInt32, width)
                p_in = ek.gather(m.Array2f, self.temp_pos, offsets)
                v_in = ek.gather(m.Array2f, self.temp_vel, offsets)

                # Reverse-mode differentiation of one time step
                ek.enable_grad(p_in, v_in)
                p_out, v_out = self.timestep(p_in, v_in)
                ek.set_grad(p_out, d_pos)
                ek.set_grad(v_out, d_vel)
                ek.enqueue(p_out, v_out)
                ek.traverse(m.Float, reverse=True)

                # Write the input gradients back into the loop state
                d_pos.assign(ek.grad(p_in))
                d_vel.assign(ek.grad(v_in))

            self.set_grad_in('pos', d_pos)
            self.set_grad_in('vel', d_vel)
Ejemplo n.º 2
0
        def backward(self):
            """Back-propagate output gradients by stepping the state backwards.

            Starting from ``self.pos`` / ``self.vel``, each of the 100 passes
            first calls ``self.timestep`` with ``dt=-0.02`` and assigns the
            result to the loop state, then calls it again with ``dt=0.02``
            from that rewound state while tracking derivatives. The gradients
            w.r.t. the rewound state replace the running gradients. The final
            values are handed to ``self.set_grad_in`` for ``'pos'`` and
            ``'vel'``.
            """
            d_pos, d_vel = self.grad_out()
            p_state, v_state = self.pos, self.vel

            # Iteration counter: counts up from 0 to 100
            step = m.UInt32(0)

            rec = m.Loop(step, p_state, v_state, d_pos, d_vel)
            while rec.cond(step < 100):
                # Rewind the state by one step (negative dt)
                p_back, v_back = self.timestep(p_state, v_state, dt=-0.02)
                p_state.assign(p_back)
                v_state.assign(v_back)

                # Replay that step forwards with derivative tracking enabled
                ek.enable_grad(p_back, v_back)
                p_next, v_next = self.timestep(p_back, v_back, dt=0.02)
                ek.set_grad(p_next, d_pos)
                ek.set_grad(v_next, d_vel)
                ek.enqueue(p_next, v_next)
                ek.traverse(m.Float, reverse=True)

                # Carry the gradients of the rewound state to the next pass
                d_pos.assign(ek.grad(p_back))
                d_vel.assign(ek.grad(v_back))
                step += 1

            self.set_grad_in('pos', d_pos)
            self.set_grad_in('vel', d_vel)
Ejemplo n.º 3
0
        def backward(self):
            """Back-propagate output gradients by stepping the state backwards.

            Same scheme as a rewind-and-replay loop: each of the 100 passes
            calls ``self.timestep`` with ``dt=-0.02`` to rewind the state,
            copies the rewound state into fresh ``m.Array2f`` values, enables
            gradients on those copies, and replays the step with ``dt=0.02``.
            The loop state is handed to ``m.Loop`` through a lambda and is
            updated by rebinding the local names. The final gradients are
            handed to ``self.set_grad_in`` for ``'pos'`` and ``'vel'``.
            """
            d_pos, d_vel = self.grad_out()
            p_state, v_state = self.pos, self.vel

            # Iteration counter: counts up from 0 to 100
            step = m.UInt32(0)

            rec = m.Loop("backward", lambda:
                         (step, p_state, v_state, d_pos, d_vel))
            while rec(step < 100):
                # Rewind the state by one step (negative dt)
                p_state, v_state = self.timestep(p_state, v_state, dt=-0.02)

                # Replay that step forwards on copies that track derivatives
                p_back, v_back = m.Array2f(p_state), m.Array2f(v_state)
                ek.enable_grad(p_back, v_back)
                p_next, v_next = self.timestep(p_back, v_back, dt=0.02)

                ek.set_grad(p_next, d_pos)
                ek.set_grad(v_next, d_vel)
                ek.enqueue(p_next, v_next)
                ek.traverse(m.Float, reverse=True)

                # Carry the gradients of the rewound state to the next pass
                d_pos = ek.grad(p_back)
                d_vel = ek.grad(v_back)
                step += 1

            self.set_grad_in('pos', d_pos)
            self.set_grad_in('vel', d_vel)
Ejemplo n.º 4
0
def test43_custom_reverse(m):
    """Check the reverse-mode gradient through ``ek.custom(Normalize, ...)``.

    Seeds the output of the custom operation with (5, 6, 7), traverses the
    graph in reverse mode, and compares the gradient that arrives at the
    input (1, 2, 3) against reference values.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    result = ek.custom(Normalize, vec)

    # Seed the output gradient and propagate it back to the input
    ek.set_grad(result, m.Array3f(5, 6, 7))
    ek.enqueue(result)
    ek.traverse(m.Float, reverse=True)

    vec_grad = ek.grad(vec)
    reference = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(vec_grad, reference)
Ejemplo n.º 5
0
def test44_custom_forward(m):
    """Check the forward-mode gradient through ``ek.custom(Normalize, ...)``.

    Runs two forward traversals seeded with (5, 6, 7) on the input. After the
    first one (``retain_graph=True``) the input gradient is expected to read
    zero and the output to hold the reference gradient; after the second one
    (``retain_graph=False``) the output gradient is expected to be twice the
    reference, i.e. the two traversals accumulate.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    result = ek.custom(Normalize, vec)

    # First pass: keep the graph so that it can be traversed again
    ek.set_grad(vec, m.Array3f(5, 6, 7))
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=True)
    assert ek.grad(vec) == 0

    # Re-seed the input before inspecting the output gradient
    ek.set_grad(vec, m.Array3f(5, 6, 7))
    reference = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(ek.grad(result), reference)

    # Second pass: the output gradient accumulates on top of the first
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=False)
    assert ek.allclose(ek.grad(result), reference * 2)
Ejemplo n.º 6
0
def test21_scatter_add_fwd(m):
    """Check forward-mode derivatives through two ``ek.scatter_add`` calls.

    Three configurations are exercised (case 0, 1, 2): the target buffer is
    differentiable when the case number is even, and the scattered values
    ``x`` / ``y`` are differentiable when ``case // 2 == 0``. In every case
    the squared norm of the scattered buffer and its forward-mode gradient
    are compared against reference values.
    """
    for case in range(3):
        seed_buf = case % 2 == 0
        seed_xy = case // 2 == 0

        first_idx = ek.arange(m.UInt, 5)
        second_idx = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        target = ek.zero(m.Float, 10)

        # Seed a unit gradient on whichever inputs are active in this case
        if seed_buf:
            ek.enable_grad(target)
            ek.set_grad(target, 1)
        if seed_xy:
            ek.enable_grad(x, y)
            ek.set_grad(x, 1)
            ek.set_grad(y, 1)

        x.label = "x"
        y.label = "y"
        target.label = "buf"

        # Scatter into a copy so that the original buffer stays untouched
        accum = m.Float(target)
        ek.scatter_add(accum, x, first_idx)
        ek.scatter_add(accum, y, second_idx)

        norm_sq = ek.dot_async(accum, accum)

        if seed_buf:
            ek.enqueue(target)
        if seed_xy:
            ek.enqueue(x, y)

        ek.traverse(m.Float, reverse=False)

        # Verified against Mathematica
        assert ek.allclose(ek.detach(norm_sq), 15.5972)
        assert ek.allclose(ek.grad(norm_sq), (25.1667 if seed_xy else 0) +
                           (17 if seed_buf else 0))
Ejemplo n.º 7
0
def test02_add_fwd(m):
    """Check forward-mode propagation through ``2 * a + b``.

    Part 1 uses the ``ek.forward`` shortcut: forwarding from the first
    operand is expected to give an output gradient of 2; after overwriting
    the output gradient with 101 and forwarding from the second operand the
    expected value is 102.

    Part 2 performs the same kind of propagation manually with
    ``set_grad`` / ``enqueue`` / ``traverse``: the first traversal is
    expected to leave 2 on the output and 0 on the input, and a second,
    identically seeded traversal is expected to raise the output to 4.
    """
    # Part 1: ek.forward() convenience wrapper
    u, v = m.Float(1), m.Float(2)
    ek.enable_grad(u, v)
    out = 2 * u + v
    ek.forward(u, retain_graph=True)
    assert ek.grad(out) == 2
    ek.set_grad(out, 101)
    ek.forward(v)
    assert ek.grad(out) == 102

    # Part 2: explicit seed / enqueue / traverse sequence
    u, v = m.Float(1), m.Float(2)
    ek.enable_grad(u, v)
    out = 2 * u + v
    ek.set_grad(u, 1.0)
    ek.enqueue(u)
    ek.traverse(m.Float, retain_graph=True, reverse=False)
    assert ek.grad(out) == 2
    assert ek.grad(u) == 0
    ek.set_grad(u, 1.0)
    ek.enqueue(u)
    ek.traverse(m.Float, retain_graph=False, reverse=False)
    assert ek.grad(out) == 4