예제 #1
0
        def backward(self):
            """Propagate the output gradients back through 100 loop iterations.

            Starts from the gradients returned by ``self.grad_out()``, counts the
            iteration counter down from 100 to 0, re-evaluates ``self.timestep``
            on the state gathered from ``self.temp_pos`` / ``self.temp_vel`` for
            each iteration, and differentiates that single step in reverse mode.
            The final gradients are handed to ``self.set_grad_in`` under the
            names ``'pos'`` and ``'vel'``.
            """
            grad_pos, grad_vel = self.grad_out()

            # Run for 100 iterations
            it = m.UInt32(100)

            # Counter and both gradients are passed to m.Loop (presumably as
            # the loop's state variables -- confirm against the Loop API)
            loop = m.Loop(it, grad_pos, grad_vel)
            n = ek.width(grad_pos)
            while loop.cond(it > 0):
                # Retrieve loop variables, reverse chronological order
                it -= 1
                # Entries [it * n, it * n + n) of the flat buffers; presumably
                # the forward pass stored one block of n entries per iteration
                # -- TODO confirm against the code that fills temp_pos/temp_vel
                index = it * n + ek.arange(m.UInt32, n)
                pos = ek.gather(m.Array2f, self.temp_pos, index)
                vel = ek.gather(m.Array2f, self.temp_vel, index)

                # Differentiate loop body in reverse mode
                ek.enable_grad(pos, vel)
                pos_out, vel_out = self.timestep(pos, vel)
                # Seed the step outputs with the gradients accumulated so far
                ek.set_grad(pos_out, grad_pos)
                ek.set_grad(vel_out, grad_vel)
                ek.enqueue(pos_out, vel_out)
                ek.traverse(m.Float, reverse=True)

                # Update loop variables
                grad_pos.assign(ek.grad(pos))
                grad_vel.assign(ek.grad(vel))

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)
예제 #2
0
        def backward(self):
            """Propagate the output gradients back through 100 loop iterations.

            Each iteration first calls ``self.timestep`` with ``dt=-0.02`` to step
            the state backward, then repeats the step with ``dt=0.02`` on
            gradient-enabled inputs and differentiates it in reverse mode. The
            final gradients are handed to ``self.set_grad_in`` under the names
            ``'pos'`` and ``'vel'``.
            """
            grad_pos, grad_vel = self.grad_out()
            # NOTE(review): pos/vel refer to the same objects as self.pos and
            # self.vel, so the assign() calls below presumably overwrite those
            # attributes in place -- confirm this is intended
            pos, vel = self.pos, self.vel

            # Run for 100 iterations
            it = m.UInt32(0)

            loop = m.Loop(it, pos, vel, grad_pos, grad_vel)
            while loop.cond(it < 100):
                # Take reverse step in time
                pos_rev, vel_rev = self.timestep(pos, vel, dt=-0.02)
                pos.assign(pos_rev)
                vel.assign(vel_rev)

                # Take a forward step in time, keep track of derivatives
                ek.enable_grad(pos_rev, vel_rev)
                pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
                # Seed the step outputs with the gradients accumulated so far
                ek.set_grad(pos_fwd, grad_pos)
                ek.set_grad(vel_fwd, grad_vel)
                ek.enqueue(pos_fwd, vel_fwd)
                ek.traverse(m.Float, reverse=True)

                # Gradients w.r.t. the step inputs become the next seed
                grad_pos.assign(ek.grad(pos_rev))
                grad_vel.assign(ek.grad(vel_rev))
                it += 1

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)
예제 #3
0
        def backward(self):
            """Propagate the output gradients back through 100 loop iterations.

            Each iteration first calls ``self.timestep`` with ``dt=-0.02`` to step
            the state backward, then repeats the step with ``dt=0.02`` on
            gradient-enabled copies and differentiates it in reverse mode. The
            final gradients are handed to ``self.set_grad_in`` under the names
            ``'pos'`` and ``'vel'``.
            """
            grad_pos, grad_vel = self.grad_out()
            pos, vel = self.pos, self.vel

            # Run for 100 iterations
            it = m.UInt32(0)

            # The lambda returns the current bindings of the listed locals each
            # time it is called, so rebinding them inside the body stays visible
            # to the loop object
            loop = m.Loop("backward", lambda:
                          (it, pos, vel, grad_pos, grad_vel))
            while loop(it < 100):
                # Take reverse step in time
                pos, vel = self.timestep(pos, vel, dt=-0.02)

                # Take a forward step in time, keep track of derivatives
                # (gradients are enabled on fresh Array2f objects built from
                # pos/vel rather than on pos/vel themselves)
                pos_rev, vel_rev = m.Array2f(pos), m.Array2f(vel)
                ek.enable_grad(pos_rev, vel_rev)
                pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)

                # Seed the step outputs with the gradients accumulated so far
                ek.set_grad(pos_fwd, grad_pos)
                ek.set_grad(vel_fwd, grad_vel)
                ek.enqueue(pos_fwd, vel_fwd)
                ek.traverse(m.Float, reverse=True)

                # Gradients w.r.t. the step inputs become the next seed
                grad_pos = ek.grad(pos_rev)
                grad_vel = ek.grad(vel_rev)
                it += 1

            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)
예제 #4
0
def test52_scatter_fwd_eager(m):
    """Check values and forward gradients of ``ek.scatter`` under ``EagerMode``.

    Scatters ``x * x * [1, 2, 3, 4]`` (with ``x = 4``) into every other slot
    of a 10-entry buffer, then overwrites slot 0 with a value that has no
    gradient attached and checks that its gradient entry becomes zero.
    """
    with EagerMode():
        seed = m.Float(4.0)
        ek.enable_grad(seed)
        ek.set_grad(seed, 1)

        weighted = seed * seed * ek.linspace(m.Float, 1, 4, 4)
        slots = 2 * ek.arange(m.UInt32, 4)

        target = ek.zero(m.Float, 10)
        ek.scatter(target, weighted, slots)

        assert ek.grad_enabled(target)

        # 16 * [1, 2, 3, 4] written to the even slots 0, 2, 4, 6
        assert ek.allclose(
            target, [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0])

        # d/dseed (seed^2 * k) at seed = 4 equals 8 * k
        assert ek.allclose(
            ek.grad(target),
            [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0])

        # Overwrite first value with non-diff value, resulting gradient entry
        # should be 0
        plain = m.Float(3)
        slots = m.UInt32(0)
        ek.scatter(target, plain, slots)

        assert ek.allclose(
            target, [3.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0])

        ek.forward(seed)

        assert ek.allclose(
            ek.grad(target),
            [0.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0])
예제 #5
0
def test49_eager_fwd(m):
    """Check the forward gradient of ``x * x * x`` under ``EagerMode``.

    With ``x = 1`` and an input gradient of 10 the expected result is
    ``3 * x**2 * 10 = 30``.
    """
    with EagerMode():
        base = m.Float(1)
        ek.enable_grad(base)
        ek.set_grad(base, 10)

        cubed = base * base * base

        assert ek.grad(cubed) == 30
예제 #6
0
def test43_custom_reverse(m):
    """Check reverse-mode traversal through the custom ``Normalize`` operation.

    Seeds the output of ``ek.custom(Normalize, ...)`` with the gradient
    ``(5, 6, 7)`` and compares the gradient that arrives at the input
    ``(1, 2, 3)`` against a fixed reference.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)

    normalized = ek.custom(Normalize, vec)

    # Seed the output and propagate back to the input
    ek.set_grad(normalized, m.Array3f(5, 6, 7))
    ek.enqueue(normalized)
    ek.traverse(m.Float, reverse=True)

    expected = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(ek.grad(vec), expected)
예제 #7
0
def test50_gather_fwd_eager(m):
    """Check forward gradients through ``ek.gather`` under ``EagerMode``.

    Gathers entries 1, 1, 2, 3 of ``x * x`` where ``x`` holds 10 evenly
    spaced samples in [-1, 1]; the expected gradient is ``2 * x`` at those
    entries.
    """
    with EagerMode():
        samples = ek.linspace(m.Float, -1, 1, 10)
        ek.enable_grad(samples)
        ek.set_grad(samples, 1)

        picked = ek.gather(m.Float, samples * samples, m.UInt(1, 1, 2, 3))

        assert ek.allclose(
            ek.grad(picked), [-1.55556, -1.55556, -1.11111, -0.666667])
예제 #8
0
def test44_custom_forward(m):
    """Check forward-mode traversal through the custom ``Normalize`` operation.

    Runs two forward traversals from the same input. The first retains the
    graph, the second does not. The test expects the input gradient to read
    as zero after the first traversal and the output gradient to double
    after the second.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    normalized = ek.custom(Normalize, vec)

    # First pass: keep the graph so that it can be traversed again
    ek.set_grad(vec, m.Array3f(5, 6, 7))
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=True)
    assert ek.grad(vec) == 0

    ek.set_grad(vec, m.Array3f(5, 6, 7))
    single_pass = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(ek.grad(normalized), single_pass)

    # Second pass: the output gradient is expected to be twice as large
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=False)
    assert ek.allclose(ek.grad(normalized), single_pass * 2)
예제 #9
0
def test_ad_operations(package):
    """Check that the AD helpers operate on a custom ``ENOKI_STRUCT`` type.

    Covers ``ek.grad_enabled``, ``ek.enable_grad``, ``ek.detach``,
    ``ek.forward``, ``ek.grad``, ``ek.set_grad`` and ``ek.accum_grad`` on a
    struct with an ``Array3f`` field ``a`` and a ``Float`` field ``b``.
    """
    Float, Array3f = package.Float, package.Array3f
    prepare(package)

    class MyStruct:
        ENOKI_STRUCT = {'a': Array3f, 'b': Float}

        def __init__(self):
            self.a = Array3f()
            self.b = Float()

    record = ek.zero(MyStruct, 4)

    # Freshly created: neither the fields nor the struct track gradients
    for part in (record.a, record.b, record):
        assert not ek.grad_enabled(part)

    ek.enable_grad(record)
    for part in (record.a, record.b, record):
        assert ek.grad_enabled(part)

    # A detached copy must not track gradients
    detached = ek.detach(record)
    for part in (detached.a, detached.b, detached):
        assert not ek.grad_enabled(part)

    # Forward-propagate from a scalar: d(x)/dx = 1, d(x*x)/dx = 8 at x = 4
    scalar = Float(4.0)
    ek.enable_grad(scalar)
    record.a += scalar
    record.b += scalar * scalar
    ek.forward(scalar)
    grads = ek.grad(record)
    assert grads.a == 1
    assert grads.b == 8

    # set_grad overwrites the stored gradient ...
    ek.set_grad(record, 5.0)
    grads = ek.grad(record)
    assert grads.a == 5.0
    assert grads.b == 5.0

    # ... while accum_grad adds to it
    ek.accum_grad(record, 5.0)
    grads = ek.grad(record)
    assert grads.a == 10.0
    assert grads.b == 10.0
예제 #10
0
def test51_scatter_reduce_fwd_eager(m):
    """Check forward gradients of ``ek.scatter_reduce`` (Add) under ``EagerMode``.

    Runs three configurations that differ in which inputs carry gradients:
    the buffer and the scattered values (0), only the values (1), only the
    buffer (2).
    """
    with EagerMode():
        for case in range(3):
            track_buf = case % 2 == 0
            track_inputs = case // 2 == 0

            first_idx = ek.arange(m.UInt, 5)
            second_idx = ek.arange(m.UInt, 4) + 3

            first = ek.linspace(m.Float, 0, 1, 5)
            second = ek.linspace(m.Float, 1, 2, 4)
            base = ek.zero(m.Float, 10)

            if track_buf:
                ek.enable_grad(base)
                ek.set_grad(base, 1)
            if track_inputs:
                ek.enable_grad(first, second)
                ek.set_grad(first, 1)
                ek.set_grad(second, 1)

            first.label = "x"
            second.label = "y"
            base.label = "buf"

            # Scatter into a copy so that `base` itself stays untouched
            accum = m.Float(base)
            ek.scatter_reduce(ek.ReduceOp.Add, accum, first, first_idx)
            ek.scatter_reduce(ek.ReduceOp.Add, accum, second, second_idx)

            energy = ek.dot_async(accum, accum)

            # Each gradient-carrying input adds its share to the total
            expected_grad = 0
            if track_inputs:
                expected_grad += 25.1667
            if track_buf:
                expected_grad += 17

            # Verified against Mathematica
            assert ek.allclose(ek.detach(energy), 15.5972)
            assert ek.allclose(ek.grad(energy), expected_grad)
예제 #11
0
def test53_scatter_fwd_permute_eager(m):
    """Check values and forward gradients of two disjoint ``ek.scatter`` calls.

    Under ``EagerMode``, two 5-entry arrays scaled by ``x = 4`` are written
    to the lower and upper half of a 10-entry buffer with ``permute=False``.
    """
    with EagerMode():
        scale = m.Float(4.0)
        ek.enable_grad(scale)
        ek.set_grad(scale, 1)

        lower = scale * ek.linspace(m.Float, 1, 9, 5)
        upper = scale * ek.linspace(m.Float, 11, 19, 5)

        target = ek.zero(m.Float, 10)

        lower_slots = ek.arange(m.UInt32, 5)
        upper_slots = ek.arange(m.UInt32, 5) + 5

        ek.scatter(target, lower, lower_slots, permute=False)
        ek.scatter(target, upper, upper_slots, permute=False)

        # 4 * [1, 3, ..., 19]
        assert ek.allclose(
            target,
            [4.0, 12.0, 20.0, 28.0, 36.0, 44.0, 52.0, 60.0, 68.0, 76.0])

        # Derivative w.r.t. the scale is the unscaled sequence itself
        assert ek.allclose(
            ek.grad(target),
            [1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0])
예제 #12
0
def test21_scatter_add_fwd(m):
    """Check forward-mode gradients through ``ek.scatter_add``.

    Runs three configurations that differ in which inputs carry gradients:
    the buffer and the scattered values (0), only the values (1), only the
    buffer (2). Gradient-carrying inputs are enqueued and traversed in
    forward mode.
    """
    for case in range(3):
        track_buf = case % 2 == 0
        track_inputs = case // 2 == 0

        first_idx = ek.arange(m.UInt, 5)
        second_idx = ek.arange(m.UInt, 4) + 3

        first = ek.linspace(m.Float, 0, 1, 5)
        second = ek.linspace(m.Float, 1, 2, 4)
        base = ek.zero(m.Float, 10)

        if track_buf:
            ek.enable_grad(base)
            ek.set_grad(base, 1)
        if track_inputs:
            ek.enable_grad(first, second)
            ek.set_grad(first, 1)
            ek.set_grad(second, 1)

        first.label = "x"
        second.label = "y"
        base.label = "buf"

        # Scatter into a copy so that `base` itself stays untouched
        accum = m.Float(base)
        ek.scatter_add(accum, first, first_idx)
        ek.scatter_add(accum, second, second_idx)

        energy = ek.dot_async(accum, accum)

        if track_buf:
            ek.enqueue(base)
        if track_inputs:
            ek.enqueue(first, second)

        ek.traverse(m.Float, reverse=False)

        # Each gradient-carrying input adds its share to the total
        expected_grad = 0
        if track_inputs:
            expected_grad += 25.1667
        if track_buf:
            expected_grad += 17

        # Verified against Mathematica
        assert ek.allclose(ek.detach(energy), 15.5972)
        assert ek.allclose(ek.grad(energy), expected_grad)
예제 #13
0
def test02_add_fwd(m):
    """Check forward-mode gradients of ``c = 2 * a + b``.

    The first part uses the ``ek.forward`` shortcut, the second part the
    explicit ``set_grad`` / ``enqueue`` / ``traverse`` sequence. Both check
    the value read from ``ek.grad(c)`` after a second pass.
    """
    # Part 1: ek.forward()
    lhs, rhs = m.Float(1), m.Float(2)
    ek.enable_grad(lhs, rhs)
    total = 2 * lhs + rhs

    ek.forward(lhs, retain_graph=True)
    assert ek.grad(total) == 2

    ek.set_grad(total, 101)
    ek.forward(rhs)
    assert ek.grad(total) == 102

    # Part 2: explicit enqueue + traverse
    lhs, rhs = m.Float(1), m.Float(2)
    ek.enable_grad(lhs, rhs)
    total = 2 * lhs + rhs

    ek.set_grad(lhs, 1.0)
    ek.enqueue(lhs)
    ek.traverse(m.Float, retain_graph=True, reverse=False)
    assert ek.grad(total) == 2
    assert ek.grad(lhs) == 0

    ek.set_grad(lhs, 1.0)
    ek.enqueue(lhs)
    ek.traverse(m.Float, retain_graph=False, reverse=False)
    assert ek.grad(total) == 4