示例#1
0
def test16_custom(cname):
    """Exercise zero/empty/select/resize/scatter/gather on a custom type.

    ``cname`` is handed to ``get_class`` to obtain the type under test. The
    test reads and writes the ``state`` and ``inc`` members of its instances.
    """
    struct_t = get_class(cname)

    first = ek.zero(struct_t, 100)
    second = ek.empty(struct_t, 100)
    assert len(first.state) == 100
    assert len(second.inc) == 100

    # Move the zero-initialized state over to `second`, then replace
    # `first.state` with the sequence 0..99 before selecting between the two.
    second.state = first.state
    first.state = ek.arange(type(first.state), 100)
    selected = ek.select(first.state < 10, first, second)
    assert selected.state[3] == 3
    assert selected.state[11] == 0

    # Resize a width-1 instance that was scheduled beforehand ...
    assert ek.width(selected) == 100
    resized = ek.zero(struct_t, 1)
    ek.schedule(resized)
    ek.resize(resized, 200)
    assert ek.width(resized) == 200

    # ... and repeat the same resize without the schedule call.
    assert ek.width(selected) == 100
    resized = ek.zero(struct_t, 1)
    ek.resize(resized, 200)
    assert ek.width(resized) == 200

    # Scatter `first` into the resized instance and gather it back with the
    # same indices; both members must match the source afterwards.
    lanes = ek.arange(type(first.state), 100)
    ek.scatter(resized, first, lanes)
    gathered = ek.gather(struct_t, resized, lanes)
    ek.eval(gathered)
    assert gathered.state == first.state and gathered.inc == first.inc
示例#2
0
        def eval(self, pos, vel):
            """Advance ``pos``/``vel`` through 100 calls to ``self.timestep``.

            Both inputs are converted to ``m.Array2f``. At the start of every
            iteration the current values are scattered into ``self.temp_pos``
            and ``self.temp_vel``, so the state seen by each step is kept.

            Returns the ``(pos, vel)`` pair after the final iteration.
            """
            pos, vel = m.Array2f(pos), m.Array2f(vel)

            # Run for 100 iterations
            it, max_it = m.UInt32(0), 100

            # Allocate scratch space: n entries for each of the max_it
            # iterations, where n is the larger of the two input widths
            n = max(ek.width(pos), ek.width(vel))
            self.temp_pos = ek.empty(m.Array2f, n * max_it)
            self.temp_vel = ek.empty(m.Array2f, n * max_it)

            # NOTE(review): presumably pos, vel and it are the variables the
            # loop object tracks across iterations -- confirm against m.Loop
            loop = m.Loop(pos, vel, it)
            while loop.cond(it < max_it):
                # Store current loop variables; iteration `it` writes to the
                # index range [it * n, it * n + n)
                index = it * n + ek.arange(m.UInt32, n)
                ek.scatter(self.temp_pos, pos, index)
                ek.scatter(self.temp_vel, vel, index)

                # Update loop variables through assign() rather than by
                # rebinding the names (presumably required so that the loop
                # keeps referring to the same objects -- TODO confirm)
                pos_out, vel_out = self.timestep(pos, vel)
                pos.assign(pos_out)
                vel.assign(vel_out)

                it += 1

            # Ensure output and temp. arrays are evaluated at this point
            # NOTE(review): only pos and vel are passed to ek.eval here; the
            # temp. arrays are not listed explicitly -- confirm they are
            # evaluated as a side effect
            ek.eval(pos, vel)

            return pos, vel
示例#3
0
def test22_scatter_fwd(m):
    """Check forward-mode gradients that pass through ``ek.scatter``.

    Values derived from a differentiable scalar ``x`` are scattered into every
    second slot of a 10-entry buffer. The buffer contents and its forward
    gradient are compared against references, first after the initial scatter
    and again after slot 0 has been overwritten by a value that does not
    depend on ``x``.
    """
    x = m.Float(4.0)
    ek.enable_grad(x)

    src = x * x * ek.linspace(m.Float, 1, 4, 4)
    slots = 2 * ek.arange(m.UInt32, 4)

    target = ek.zero(m.Float, 10)
    ek.scatter(target, src, slots)

    assert ek.grad_enabled(target)

    expected = [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(target, expected)

    # The graph is retained because a second forward pass follows below.
    ek.forward(x, retain_graph=True)
    d_target = ek.grad(target)

    expected_grad = [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(d_target, expected_grad)

    # Replace slot 0 with a value unrelated to x; its gradient entry is
    # expected to drop to 0 while the remaining entries stay as before.
    plain = m.Float(3)
    slots = m.UInt32(0)
    ek.scatter(target, plain, slots)

    expected = [3.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(target, expected)

    ek.forward(x)
    d_target = ek.grad(target)

    expected_grad = [0.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(d_target, expected_grad)
示例#4
0
def test07_loop_nest(pkg, variant):
    """Scatter Collatz step counts for 1..10 into a buffer, two ways.

    ``pkg`` is passed to ``get_class`` to obtain the backend ``p``.
    With ``variant == 0`` the outer iteration over 1..10 is itself a
    ``p.Loop`` that calls the inner ``collatz`` loop; with any other value the
    outer iteration is a plain Python ``for`` loop. Both variants must leave
    the same contents in the buffer.
    """
    p = get_class(pkg)

    def collatz(value: p.Int):
        """Return how many Collatz steps ``value`` needs to reach 1."""
        counter = p.Int(0)
        loop = p.Loop(value, counter)
        while loop.cond(ek.neq(value, 1)):
            is_even = ek.eq(value & 1, 0)
            value.assign(ek.select(is_even, value // 2, 3 * value + 1))
            counter += 1
        return counter

    i = p.Int(1)
    # 16 entries, all set to 1000; entries that are never written keep it.
    buf = ek.full(p.Int, 1000, 16)
    ek.eval(buf)

    if variant == 0:
        loop_1 = p.Loop(i)
        while loop_1.cond(i <= 10):
            # A copy of i is passed because collatz() modifies its argument.
            ek.scatter(buf, collatz(p.Int(i)), i - 1)
            i += 1
    else:
        # The for statement rebinds i on every iteration, so no manual
        # increment is needed (the former `i += 1` here had no effect).
        for i in range(1, 11):
            ek.scatter(buf, collatz(p.Int(i)), i - 1)

    assert buf == p.Int(0, 1, 7, 2, 5, 8, 16, 3, 19, 6, 1000, 1000, 1000, 1000,
                        1000, 1000)
示例#5
0
def test22_scatter_rev(m):
    """Check reverse-mode gradients through two overlapping scatters.

    The body runs three times with different gradient-tracking setups:
    i == 0 tracks ``buf``, ``x`` and ``y``; i == 1 tracks only ``x`` and
    ``y``; i == 2 tracks only ``buf``.
    """
    for i in range(3):
        # idx1 covers slots 0..4, idx2 covers slots 3..6, so the second
        # scatter overwrites slots 3 and 4 written by the first one
        idx1 = ek.arange(m.UInt, 5)
        idx2 = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if i % 2 == 0:
            ek.enable_grad(buf)
        if i // 2 == 0:
            ek.enable_grad(x, y)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        # The scatters below target buf2, with an explicit evaluation
        # between the two writes
        buf2 = m.Float(buf)
        ek.scatter(buf2, x, idx1)
        ek.eval(buf2)
        ek.scatter(buf2, y, idx2)

        ref_buf = m.Float(0.0000, 0.2500, 0.5000, 1.0000, 1.3333, 1.6667,
                          2.0000, 0.0000, 0.0000, 0.0000)

        assert ek.allclose(ref_buf, buf2, atol=1e-4)
        # NOTE(review): buf itself is expected to match the scattered
        # reference too; presumably m.Float(buf) shares storage with buf --
        # confirm against the library's copy semantics
        assert ek.allclose(ref_buf, buf, atol=1e-4)

        # s = sum(buf2 ** 2), hence the reference gradients below are twice
        # the scattered values that survive in buf2
        s = ek.dot_async(buf2, buf2)

        ek.backward(s)

        # The last two entries of ref_x are 0 because slots 3 and 4 were
        # overwritten by y
        ref_x = m.Float(0.0000, 0.5000, 1.0000, 0.0000, 0.0000)
        ref_y = m.Float(2.0000, 2.6667, 3.3333, 4.0000)

        if i // 2 == 0:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)
        else:
            # x and y were not tracked in this configuration
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0

        # The gradient of buf is expected to be zero in every configuration
        if i % 2 == 0:
            assert ek.allclose(ek.grad(buf), 0, atol=1e-4)
        else:
            assert ek.grad(buf) == 0
示例#6
0
def test22_scatter_fwd_permute(m):
    """Check forward-mode gradients through two disjoint scatters.

    Two value arrays that depend on a differentiable scalar ``x`` are written
    to slots 0..4 and 5..9 of one buffer. Both ``ek.scatter`` calls pass
    ``permute=False``. The buffer and its forward gradient are then compared
    against reference lists.
    """
    x = m.Float(4.0)
    ek.enable_grad(x)

    lower_vals = x * ek.linspace(m.Float, 1, 9, 5)
    upper_vals = x * ek.linspace(m.Float, 11, 19, 5)

    target = ek.zero(m.Float, 10)

    lower_slots = ek.arange(m.UInt32, 5)
    upper_slots = ek.arange(m.UInt32, 5) + 5

    ek.scatter(target, lower_vals, lower_slots, permute=False)
    ek.scatter(target, upper_vals, upper_slots, permute=False)

    expected = [4.0, 12.0, 20.0, 28.0, 36.0, 44.0, 52.0, 60.0, 68.0, 76.0]
    assert ek.allclose(target, expected)

    ek.forward(x)
    d_target = ek.grad(target)

    # Each value is x times a constant, so the derivative is that constant.
    expected_grad = [1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0]
    assert ek.allclose(d_target, expected_grad)
 def unravel(source, target, dim=3):
     """Interleave the first ``dim`` components of ``source`` into ``target``.

     Entry ``k`` of component ``c`` is written to ``target`` at position
     ``dim * k + c``.
     """
     lane = UInt32.arange(ek.slices(source))
     for component in range(dim):
         destination = dim * lane + component
         ek.scatter(target, source[component], destination)
示例#8
0
def scatter_(self, target, index, mask, permute):
    """Scatter the entries of ``self`` into ``target`` one entry at a time.

    ``target`` must have ``Depth == 1``. The number of scatter calls is the
    largest of the lengths of ``self``, ``index`` and ``mask``; each call
    forwards the matching entries together with ``permute`` to
    ``_ek.scatter``.
    """
    assert target.Depth == 1
    lengths = (len(self), len(index), len(mask))
    count = max(lengths)
    for k in range(count):
        _ek.scatter(target, self[k], index[k], mask[k], permute)
示例#9
0
def test54_scatter_implicit_detach(m):
    """Run a masked scatter whose target and value are both detached.

    The test contains no assertion; it only requires the ``ek.scatter`` call
    to complete.
    """
    target = ek.detach(m.Float(0))
    value = ek.detach(m.Float(1))
    slot = m.UInt32(0)
    active = m.Bool(True)
    ek.scatter(target, value, slot, active)