コード例 #1
0
def test22_scatter_fwd(m):
    """Forward-mode derivative of ``ek.scatter`` into a zero-initialized buffer.

    Four differentiable values are written to the even slots (0, 2, 4, 6) of
    a 10-entry buffer; the forward gradient must appear only at those slots.
    A second scatter of a constant then overwrites slot 0, which must zero
    the gradient at that slot while leaving the others unchanged.
    """
    src = m.Float(4.0)
    ek.enable_grad(src)

    # src^2 * [1, 2, 3, 4] = [16, 32, 48, 64]; derivative is 2*src*[1, 2, 3, 4]
    scaled = src * src * ek.linspace(m.Float, 1, 4, 4)
    even_slots = 2 * ek.arange(m.UInt32, 4)

    target = ek.zero(m.Float, 10)
    ek.scatter(target, scaled, even_slots)

    assert ek.grad_enabled(target)

    expected_vals = [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(target, expected_vals)

    # Keep the graph: it is traversed a second time further down.
    ek.forward(src, retain_graph=True)
    tangent = ek.grad(target)

    expected_tangent = [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(tangent, expected_tangent)

    # Overwrite slot 0 with a non-differentiable constant; its gradient
    # entry must become 0.
    const_val = m.Float(3)
    slot_zero = m.UInt32(0)
    ek.scatter(target, const_val, slot_zero)

    expected_vals = [3.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(target, expected_vals)

    ek.forward(src)
    tangent = ek.grad(target)

    expected_tangent = [0.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(tangent, expected_tangent)
コード例 #2
0
        def backward(self):
            """Propagate output gradients back through 100 simulation steps.

            Instead of reading stored intermediate states, this variant
            reconstructs them by integrating backwards in time
            (``dt=-0.02``). At each reconstructed state it re-runs one
            forward step (``dt=0.02``) with gradient tracking enabled. It then
            back-propagates the current ``grad_pos``/``grad_vel`` through
            that step, which yields the gradients with respect to the earlier
            state.

            NOTE(review): assumes ``self.timestep`` is (numerically) reversible
            with a negative ``dt`` and that ``self.pos``/``self.vel`` hold the
            state reached at the end of the forward pass — confirm against the
            forward implementation.
            """
            grad_pos, grad_vel = self.grad_out()
            # NOTE(review): pos/vel are the very objects stored on self, so the
            # assign() calls below appear to overwrite self.pos/self.vel in
            # place — confirm this side effect is intended.
            pos, vel = self.pos, self.vel

            # Run for 100 iterations
            it = m.UInt32(0)

            # This Loop form receives the state arrays directly. They are
            # presumably registered as loop state, which is why the body
            # updates them in place (assign / +=) rather than rebinding names.
            loop = m.Loop(it, pos, vel, grad_pos, grad_vel)
            while loop.cond(it < 100):
                # Take reverse step in time
                pos_rev, vel_rev = self.timestep(pos, vel, dt=-0.02)
                pos.assign(pos_rev)
                vel.assign(vel_rev)

                # Take a forward step in time, keep track of derivatives
                ek.enable_grad(pos_rev, vel_rev)
                pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
                # Seed the step outputs with the incoming gradients and
                # propagate them back to the step inputs (pos_rev, vel_rev).
                ek.set_grad(pos_fwd, grad_pos)
                ek.set_grad(vel_fwd, grad_vel)
                ek.enqueue(pos_fwd, vel_fwd)
                ek.traverse(m.Float, reverse=True)

                # The gradients w.r.t. this step's inputs become the incoming
                # gradients for the next (earlier) step.
                grad_pos.assign(ek.grad(pos_rev))
                grad_vel.assign(ek.grad(vel_rev))
                it += 1

            # Report the resulting gradients for the inputs named 'pos' and 'vel'.
            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)
コード例 #3
0
def test01_add_rev(m):
    """Reverse-mode gradient of ``2*a + b``: d/da = 2 and d/db = 1."""
    lhs = m.Float(1)
    rhs = m.Float(2)
    ek.enable_grad(lhs, rhs)
    out = 2 * lhs + rhs
    ek.backward(out)
    assert ek.grad(lhs) == 2
    assert ek.grad(rhs) == 1
コード例 #4
0
        def backward(self):
            """Propagate output gradients back through 100 simulation steps.

            Earlier states are reconstructed by integrating backwards in time
            (``dt=-0.02``) rather than read from storage. For each
            reconstructed state, a copy is made with gradient tracking enabled
            and one forward step (``dt=0.02``) is re-run on it. The current
            ``grad_pos``/``grad_vel`` are then back-propagated through that
            step.

            NOTE(review): assumes ``self.timestep`` is (numerically) reversible
            with a negative ``dt`` and that ``self.pos``/``self.vel`` hold the
            state at the end of the forward pass — confirm.
            """
            grad_pos, grad_vel = self.grad_out()
            pos, vel = self.pos, self.vel

            # Run for 100 iterations
            it = m.UInt32(0)

            # The lambda presumably lets the Loop re-read the current loop
            # state, which is why the body below rebinds pos/vel/grad_pos/
            # grad_vel instead of updating them in place.
            loop = m.Loop("backward", lambda:
                          (it, pos, vel, grad_pos, grad_vel))
            while loop(it < 100):
                # Take reverse step in time
                pos, vel = self.timestep(pos, vel, dt=-0.02)

                # Take a forward step in time, keep track of derivatives.
                # Fresh Array2f objects (presumably copies) are used so that
                # enabling gradients does not affect the loop variables.
                pos_rev, vel_rev = m.Array2f(pos), m.Array2f(vel)
                ek.enable_grad(pos_rev, vel_rev)
                pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)

                # Seed the step outputs with the incoming gradients and
                # propagate them back to the step inputs.
                ek.set_grad(pos_fwd, grad_pos)
                ek.set_grad(vel_fwd, grad_vel)
                ek.enqueue(pos_fwd, vel_fwd)
                ek.traverse(m.Float, reverse=True)

                # Gradients w.r.t. this step's inputs feed the next (earlier)
                # step.
                grad_pos = ek.grad(pos_rev)
                grad_vel = ek.grad(vel_rev)
                it += 1

            # Report the resulting gradients for the inputs named 'pos' and 'vel'.
            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)
コード例 #5
0
def test04_div(m):
    """Reverse-mode gradient of ``a / b``: d/da = 1/b and d/db = -a/b^2."""
    num = m.Float(2)
    den = m.Float(3)
    ek.enable_grad(num, den)
    quotient = num / den
    ek.backward(quotient)
    assert ek.allclose(ek.grad(num), 1.0 / 3.0)
    assert ek.allclose(ek.grad(den), -2.0 / 9.0)
コード例 #6
0
def test06_hsum_0_fwd(m):
    """Forward-mode derivative of ``hsum(x^2)`` over 10 samples in [0, 1].

    The sum of squares is 95/27. The expected derivative, 10, equals
    sum(2*x), i.e. the result of a unit tangent on every entry.
    """
    samples = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(samples)
    total = ek.hsum_async(samples * samples)
    ek.forward(samples)
    assert len(total) == 1
    assert ek.allclose(ek.detach(total), 95.0 / 27.0)
    assert len(ek.grad(total)) == 1
    assert ek.allclose(ek.grad(total), 10)
コード例 #7
0
        def backward(self):
            """Propagate output gradients back through 100 recorded steps.

            This variant gathers per-step states from ``self.temp_pos`` and
            ``self.temp_vel`` and visits the steps in reverse order (99 down
            to 0). For each step it re-runs ``self.timestep`` on the gathered
            state with gradient tracking enabled. It then back-propagates the
            current ``grad_pos``/``grad_vel`` through that step.

            NOTE(review): assumes the forward pass recorded the state entering
            step ``k`` at flat offsets ``k*n .. k*n + n - 1`` of
            ``temp_pos``/``temp_vel`` (``n`` = width of the gradient arrays),
            and that ``self.timestep``'s default ``dt`` matches the forward
            pass — confirm.
            """
            grad_pos, grad_vel = self.grad_out()

            # Run for 100 iterations
            it = m.UInt32(100)

            # The loop state is updated in place (it -= 1, assign) below,
            # presumably because this Loop form registers the arrays directly.
            loop = m.Loop(it, grad_pos, grad_vel)
            # Number of lanes; step `it` occupies offsets it*n .. it*n + n - 1.
            n = ek.width(grad_pos)
            while loop.cond(it > 0):
                # Retrieve loop variables, reverse chronological order
                # (decrement first so that indices run 99, 98, ..., 0)
                it -= 1
                index = it * n + ek.arange(m.UInt32, n)
                pos = ek.gather(m.Array2f, self.temp_pos, index)
                vel = ek.gather(m.Array2f, self.temp_vel, index)

                # Differentiate loop body in reverse mode
                ek.enable_grad(pos, vel)
                pos_out, vel_out = self.timestep(pos, vel)
                ek.set_grad(pos_out, grad_pos)
                ek.set_grad(vel_out, grad_vel)
                ek.enqueue(pos_out, vel_out)
                ek.traverse(m.Float, reverse=True)

                # Update loop variables: gradients w.r.t. this step's inputs
                # become the incoming gradients of the previous step.
                grad_pos.assign(ek.grad(pos))
                grad_vel.assign(ek.grad(vel))

            # Report the resulting gradients for the inputs named 'pos' and 'vel'.
            self.set_grad_in('pos', grad_pos)
            self.set_grad_in('vel', grad_vel)
コード例 #8
0
def test03_sub_mul(m):
    """Reverse-mode gradient of ``a*b - c``: d/da = b, d/db = a, d/dc = -1."""
    fa = m.Float(2)
    fb = m.Float(3)
    fc = m.Float(4)
    ek.enable_grad(fa, fb, fc)
    diff = fa * fb - fc
    ek.backward(diff)
    assert ek.grad(fa) == ek.detach(fb)
    assert ek.grad(fb) == ek.detach(fa)
    assert ek.grad(fc) == -1
コード例 #9
0
def test25_pow(m):
    """Reverse-mode gradients of ``x**y`` with respect to base and exponent."""
    base = ek.linspace(m.Float, 1, 10, 10)
    expo = ek.full(m.Float, 2.0, 10)
    ek.enable_grad(base, expo)
    powered = base**expo
    ek.backward(powered)

    # d/d(base) of base^2 is 2 * base
    assert ek.allclose(ek.grad(base), ek.detach(base) * 2)

    # d/d(expo) of base^expo is base^expo * ln(base) = base^2 * ln(base)
    expected_dexpo = m.Float(0., 2.77259, 9.88751, 22.1807, 40.2359, 64.5033,
                             95.3496, 133.084, 177.975, 230.259)
    assert ek.allclose(ek.grad(expo), expected_dexpo)
コード例 #10
0
def test41_replace_grad(m):
    """``ek.replace_grad(a, b)`` keeps the value of ``a`` but routes gradients to ``b``.

    The result's value comes from x^2 (so z^2 = x^4), while all of the
    gradient flows to y and none to x.
    """
    src_val = m.Array3f(1, 2, 3)
    src_grad = m.Array3f(3, 2, 1)
    ek.enable_grad(src_val, src_grad)
    val_sq = src_val * src_val
    grad_sq = src_grad * src_grad
    hybrid = ek.replace_grad(val_sq, grad_sq)
    hybrid_sq = hybrid * hybrid
    ek.backward(hybrid_sq)
    assert ek.allclose(hybrid_sq, [1, 16, 81])
    assert ek.grad(src_val) == 0
    # 2*z * 2*y with z = x^2: [2*1*6, 2*4*4, 2*9*2]
    assert ek.allclose(ek.grad(src_grad), [12, 32, 36])
コード例 #11
0
def test44_custom_forward(m):
    """Forward-mode propagation through the custom ``Normalize`` operation.

    ``Normalize`` is defined elsewhere in the test module. After a forward
    traversal, the input gradient reads back as zero. The output gradient
    from the first traversal is still there because the graph was retained,
    and a second traversal adds to it, doubling the value.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    normed = ek.custom(Normalize, vec)

    # First traversal keeps the graph alive for the second one.
    ek.set_grad(vec, m.Array3f(5, 6, 7))
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=True)
    assert ek.grad(vec) == 0

    ek.set_grad(vec, m.Array3f(5, 6, 7))
    expected = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(ek.grad(normed), expected)

    # Second traversal accumulates onto the existing output gradient.
    ek.enqueue(vec)
    ek.traverse(m.Float, reverse=False, retain_graph=False)
    assert ek.allclose(ek.grad(normed), expected * 2)
コード例 #12
0
def test19_gather_fwd(m):
    """Forward-mode derivative of gathering entries 1, 1, 2, 3 of ``x^2``.

    Each gathered entry carries the derivative 2*x at its source index.
    """
    grid = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(grid)
    picked = ek.gather(m.Float, grid * grid, m.UInt(1, 1, 2, 3))
    ek.forward(grid)
    expected = [-1.55556, -1.55556, -1.11111, -0.666667]
    assert ek.allclose(ek.grad(picked), expected)
コード例 #13
0
def test15_abs(m):
    """Reverse-mode derivative of ``abs`` is the sign of the input."""
    inp = m.Float(-2, 2)
    ek.enable_grad(inp)
    magnitude = ek.abs(inp)
    ek.backward(magnitude)
    assert ek.allclose(ek.detach(magnitude), [2, 2])
    assert ek.allclose(ek.grad(inp), [-1, 1])
コード例 #14
0
def test14_rsqrt(m):
    """Reverse-mode derivative of ``rsqrt``: -1/2 * x^(-3/2)."""
    inp = m.Float(1, .25, 0.0625)
    ek.enable_grad(inp)
    inv_root = ek.rsqrt(inp)
    ek.backward(inv_root)
    assert ek.allclose(ek.detach(inv_root), [1, 2, 4])
    assert ek.allclose(ek.grad(inp), [-.5, -4, -32])
コード例 #15
0
def test13_sqrt(m):
    """Reverse-mode derivative of ``sqrt``: 1 / (2 * sqrt(x))."""
    inp = m.Float(1, 4, 16)
    ek.enable_grad(inp)
    root = ek.sqrt(inp)
    ek.backward(root)
    assert ek.allclose(ek.detach(root), [1, 2, 4])
    assert ek.allclose(ek.grad(inp), [.5, .25, .125])
コード例 #16
0
def test12_hmax_fwd(m):
    """Forward-mode derivative of a horizontal max with a tied maximum.

    The input holds two entries equal to the maximum (8). The expected
    forward derivative, 2, is marked in the original test as an
    approximation.
    """
    vals = m.Float(1, 2, 8, 5, 8)
    ek.enable_grad(vals)
    peak = ek.hmax_async(vals)
    ek.forward(vals)
    assert len(peak) == 1
    assert ek.allclose(peak[0], 8)
    assert ek.allclose(ek.grad(peak), [2])  # Approximation
コード例 #17
0
def test11_hmax_rev(m):
    """Reverse-mode gradient of a horizontal max: 1 at every maximal entry."""
    vals = m.Float(1, 2, 8, 5, 8)
    ek.enable_grad(vals)
    peak = ek.hmax_async(vals)
    ek.backward(peak)
    assert len(peak) == 1
    assert ek.allclose(peak[0], 8)
    # Both tied maxima (indices 2 and 4) receive the gradient.
    assert ek.allclose(ek.grad(vals), [0, 0, 1, 0, 1])
コード例 #18
0
def test11_hprod(m):
    """Reverse-mode gradient of a horizontal product.

    Each entry's gradient is the product of all other entries (80 / x_i here).
    """
    factors = m.Float(1, 2, 5, 8)
    ek.enable_grad(factors)
    product = ek.hprod_async(factors)
    ek.backward(product)
    assert len(product) == 1
    assert ek.allclose(product[0], 80)
    assert ek.allclose(ek.grad(factors), [80, 40, 16, 10])
コード例 #19
0
def test17_cos(m):
    """Reverse-mode derivative of ``cos`` is ``-sin``."""
    angles = ek.linspace(m.Float, 0.01, 10, 10)
    ek.enable_grad(angles)
    cosine = ek.cos(angles)
    ek.backward(cosine)
    plain = ek.detach(angles)
    assert ek.allclose(ek.detach(cosine), ek.cos(plain))
    assert ek.allclose(ek.grad(angles), -ek.sin(plain))
コード例 #20
0
def test49_eager_fwd(m):
    """Eager forward-mode derivative of ``x^3``.

    ``EagerMode`` is defined elsewhere. The test reads the output gradient
    right after the computation, with no explicit traversal:
    3 * 1^2 * 10 = 30.
    """
    with EagerMode():
        val = m.Float(1)
        ek.enable_grad(val)
        ek.set_grad(val, 10)
        cube = val * val * val
        assert ek.grad(cube) == 30
コード例 #21
0
def test51_scatter_reduce_fwd_eager(m):
    """Eager forward-mode derivative through two overlapping scatter-adds.

    Three seeding configurations are exercised:

    * i=0: tangents on both the buffer and the scattered values;
    * i=1: tangents on the values only;
    * i=2: tangents on the buffer only.

    The derivative of ``s = dot(buf2, buf2)`` is the sum of the contributions
    from whichever inputs carry a tangent.
    """
    with EagerMode():
        for i in range(3):
            seed_buffer = i % 2 == 0
            seed_values = i // 2 == 0

            first_idx = ek.arange(m.UInt, 5)
            second_idx = ek.arange(m.UInt, 4) + 3

            xs = ek.linspace(m.Float, 0, 1, 5)
            ys = ek.linspace(m.Float, 1, 2, 4)
            base = ek.zero(m.Float, 10)

            if seed_buffer:
                ek.enable_grad(base)
                ek.set_grad(base, 1)
            if seed_values:
                ek.enable_grad(xs, ys)
                ek.set_grad(xs, 1)
                ek.set_grad(ys, 1)

            xs.label = "x"
            ys.label = "y"
            base.label = "buf"

            # Slots 3 and 4 receive contributions from both scatters.
            merged = m.Float(base)
            ek.scatter_reduce(ek.ReduceOp.Add, merged, xs, first_idx)
            ek.scatter_reduce(ek.ReduceOp.Add, merged, ys, second_idx)

            s = ek.dot_async(merged, merged)

            # Reference values verified against Mathematica
            expected_grad = ((25.1667 if seed_values else 0) +
                             (17 if seed_buffer else 0))
            assert ek.allclose(ek.detach(s), 15.5972)
            assert ek.allclose(ek.grad(s), expected_grad)
コード例 #22
0
def test05_hsum_0_rev(m):
    """Reverse-mode gradient of ``hsum(x^2)``: each entry receives 2*x."""
    pts = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(pts)
    sq_sum = ek.hsum_async(pts * pts)
    ek.backward(sq_sum)
    assert len(sq_sum) == 1
    assert ek.allclose(sq_sum, 95.0 / 27.0)
    assert ek.allclose(ek.grad(pts), 2 * ek.detach(pts))
コード例 #23
0
def test23_exp(m):
    """Reverse-mode derivative of ``exp(x^2)``: 2*x*exp(x^2)."""
    pts = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(pts)
    out = ek.exp(pts * pts)
    ek.backward(out)
    plain = ek.detach(pts)
    expected = ek.exp(ek.sqr(plain))
    assert ek.allclose(out, expected)
    assert ek.allclose(ek.grad(pts), 2 * plain * expected)
コード例 #24
0
def test09_hsum_2_rev(m):
    """Reverse-mode gradient of a nested horizontal sum.

    z = hsum(hsum(x^2) * x^2) = S^2 with S = sum(x^2) = 3.85, so
    dz/dx_i = 4 * S * x_i = 15.4 * x_i.
    """
    pts = ek.linspace(m.Float, 0, 1, 11)
    ek.enable_grad(pts)
    inner = ek.hsum_async(pts * pts)
    nested = ek.hsum_async(inner * pts * pts)
    ek.backward(nested)
    expected = [0., 1.54, 3.08, 4.62, 6.16, 7.7, 9.24, 10.78, 12.32, 13.86,
                15.4]
    assert ek.allclose(ek.grad(pts), expected)
コード例 #25
0
def test18_gather(m):
    """Reverse-mode gradient through a gather of entries 1, 1, 2, 3 of ``x^2``.

    Index 1 is gathered twice, so its gradient 2*x is accumulated twice.
    """
    grid = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(grid)
    picked = ek.gather(m.Float, grid * grid, m.UInt(1, 1, 2, 3))
    total = ek.hsum_async(picked)
    ek.backward(total)
    expected = [0, -1.55556 * 2, -1.11111, -0.666667, 0, 0, 0, 0, 0, 0]
    assert ek.allclose(ek.grad(grid), expected)
コード例 #26
0
def test50_gather_fwd_eager(m):
    """Eager forward-mode derivative of gathering entries 1, 1, 2, 3 of ``x^2``.

    ``EagerMode`` is defined elsewhere. The gathered gradient is read right
    after the gather, with no explicit traversal.
    """
    with EagerMode():
        grid = ek.linspace(m.Float, -1, 1, 10)
        ek.enable_grad(grid)
        ek.set_grad(grid, 1)
        picked = ek.gather(m.Float, grid * grid, m.UInt(1, 1, 2, 3))
        expected = [-1.55556, -1.55556, -1.11111, -0.666667]
        assert ek.allclose(ek.grad(picked), expected)
コード例 #27
0
def test43_custom_reverse(m):
    """Reverse-mode propagation through the custom ``Normalize`` operation.

    ``Normalize`` is defined elsewhere in the test module.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    normed = ek.custom(Normalize, vec)
    ek.set_grad(normed, m.Array3f(5, 6, 7))
    ek.enqueue(normed)
    ek.traverse(m.Float, reverse=True)
    expected = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(ek.grad(vec), expected)
コード例 #28
0
def test24_log(m):
    """Reverse-mode derivative of ``log(x^2)``: 2 / x."""
    pts = ek.linspace(m.Float, 0.01, 1, 10)
    ek.enable_grad(pts)
    out = ek.log(pts * pts)
    ek.backward(out)
    plain = ek.detach(pts)
    expected = ek.log(ek.sqr(plain))
    assert ek.allclose(out, expected)
    assert ek.allclose(ek.grad(pts), 2 / plain)
コード例 #29
0
def test40_safe_functions(m):
    """Safe sqrt/acos/asin variants yield finite gradients at domain edges.

    The inputs include 0 (for sqrt) and -1/+1 (for acos/asin), where the
    plain derivatives would be infinite.
    """
    sx = ek.linspace(m.Float, 0, 1, 10)
    sy = ek.linspace(m.Float, -1, 1, 10)
    sz = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(sx, sy, sz)
    root = ek.safe_sqrt(sx)
    arccos = ek.safe_acos(sy)
    arcsin = ek.safe_asin(sz)
    for out in (root, arccos, arcsin):
        ek.backward(out)

    # At x = 0 the safe sqrt reports a zero gradient.
    assert ek.grad(sx)[0] == 0
    # linspace(0, 1, 10) places 1/9 at index 1.
    assert ek.allclose(ek.grad(sx)[1], .5 / ek.sqrt(1 / 9))
    assert sx[0] == 0
    for arg in (sx, sy, sz):
        assert ek.all(ek.isfinite(ek.grad(arg)))
コード例 #30
0
def test42_suspend_resume(m):
    """Suspending gradients hides them; resuming restores normal tracking.

    A product is computed while gradients are suspended. After resuming, a
    fresh product back-propagates as usual. The test ends with the variables
    suspended to exercise their cleanup path.
    """
    lhs = m.Array3f(1, 2, 3)
    rhs = m.Array3f(3, 2, 1)
    ek.enable_grad(lhs, rhs)
    assert all(ek.grad_enabled(v) for v in (lhs, rhs))
    assert not any(ek.grad_suspended(v) for v in (lhs, rhs))

    ek.suspend_grad(lhs, rhs)
    assert not any(ek.grad_enabled(v) for v in (lhs, rhs))
    assert all(ek.grad_suspended(v) for v in (lhs, rhs))
    suspended_prod = lhs * rhs

    ek.resume_grad(lhs, rhs)
    assert all(ek.grad_enabled(v) for v in (lhs, rhs))
    assert not any(ek.grad_suspended(v) for v in (lhs, rhs))
    tracked_prod = lhs * rhs
    ek.backward(tracked_prod)
    assert ek.grad(lhs) == ek.detach(rhs)
    assert ek.grad(rhs) == ek.detach(lhs)

    ek.suspend_grad(lhs, rhs)  # validate reference counting of suspended variables