예제 #1
0
def test51_scatter_reduce_fwd_eager(m):
    """Check the gradient of a dot product fed by two scatter-adds in eager mode.

    Three passes are run inside ``EagerMode``. Each pass seeds a gradient of 1
    on a different subset of the inputs before the scatters are issued:

    * pass 0: ``buf`` and ``x``/``y``
    * pass 1: ``x``/``y`` only
    * pass 2: ``buf`` only

    ``x`` is scattered to indices 0..4 and ``y`` to indices 3..6, so the two
    index ranges overlap at 3 and 4. The primal result and the gradient read
    back from the dot product are compared against fixed reference values.

    ``m`` supplies the array types used here (``m.UInt``, ``m.Float``).
    """
    with EagerMode():
        for case in range(3):
            # Which inputs take part in differentiation during this pass.
            buf_tracked = case % 2 == 0
            xy_tracked = case // 2 == 0

            first_idx = ek.arange(m.UInt, 5)
            second_idx = ek.arange(m.UInt, 4) + 3

            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if buf_tracked:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if xy_tracked:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label = "x"
            y.label = "y"
            buf.label = "buf"

            # Scatter into a copy so that `buf` itself is left untouched.
            accum = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, accum, x, first_idx)
            ek.scatter_reduce(ek.ReduceOp.Add, accum, y, second_idx)

            s = ek.dot_async(accum, accum)

            # Each tracked input group contributes its own share of the
            # expected gradient; an untracked group contributes nothing.
            expected_grad = (25.1667 if xy_tracked else 0) + \
                (17 if buf_tracked else 0)

            # Verified against Mathematica
            assert ek.allclose(ek.detach(s), 15.5972)
            assert ek.allclose(ek.grad(s), expected_grad)
예제 #2
0
def test_58_diffloop_masking_rev(m, no_record):
    """Check gradients of a scatter-add issued inside a partially active loop.

    The loop counter has two entries starting at 0 and 5. With the loop
    condition ``counter < 5`` the first entry satisfies it five times and the
    second never does. After ``ek.backward`` on the scatter target, the
    gradient of the scattered source is expected to be ``(5, 0)``.

    ``m`` supplies the array and loop types. ``no_record`` is not referenced
    in the body (presumably a pytest fixture used for its side effect --
    confirm against the test configuration).
    """
    out = ek.zero(m.Float, 10)
    src = m.Float(1, 2)
    counter = m.UInt32(0, 5)
    ek.enable_grad(src)

    loop = m.Loop("MyLoop", lambda: counter)
    while loop(counter < 5):
        ek.scatter_reduce(ek.ReduceOp.Add, out, src, counter)
        counter += 1

    ek.backward(out)

    src_grad = ek.grad(src)
    assert src_grad == m.Float(5, 0)
예제 #3
0
def test20_scatter_reduce_rev(m):
    """Check ``ek.backward`` through two overlapping scatter-adds.

    Three passes are run, each enabling gradients on a different subset of
    the inputs:

    * pass 0: ``buf`` and ``x``/``y``
    * pass 1: ``x``/``y`` only
    * pass 2: ``buf`` only

    ``x`` is scattered to indices 0..4 and ``y`` to indices 3..6 of a copy of
    ``buf``. The scattered buffer is compared against a reference, then the
    dot product of the buffer with itself is back-propagated and the input
    gradients are compared against reference values (or zero for inputs that
    were not tracked in that pass).

    ``m`` supplies the array types used here (``m.UInt``, ``m.Float``).
    """
    for case in range(3):
        # Which inputs take part in differentiation during this pass.
        buf_tracked = case % 2 == 0
        xy_tracked = case // 2 == 0

        first_idx = ek.arange(m.UInt, 5)
        second_idx = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if buf_tracked:
            ek.enable_grad(buf)
        if xy_tracked:
            ek.enable_grad(x, y)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        # Scatter into a copy so that `buf` itself is left untouched.
        accum = m.Float(buf)
        ek.scatter_reduce(ek.ReduceOp.Add, accum, x, first_idx)
        ek.scatter_reduce(ek.ReduceOp.Add, accum, y, second_idx)

        ref_buf = m.Float(0.0, 0.25, 0.5, 1.75, 2.3333, 1.6667,
                          2.0, 0.0, 0.0, 0.0)

        assert ek.allclose(ref_buf, accum, atol=1e-4)

        s = ek.dot_async(accum, accum)

        ek.backward(s)

        ref_x = m.Float(0.0, 0.5, 1.0, 3.5, 4.6667)
        ref_y = m.Float(3.5, 4.6667, 3.3333, 4.0)

        if not xy_tracked:
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0
        else:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)

        if not buf_tracked:
            assert ek.grad(buf) == 0
        else:
            # The reference gradient of `buf` is twice the scattered buffer.
            assert ek.allclose(ek.grad(buf), ek.detach(ref_buf) * 2, atol=1e-4)
예제 #4
0
def test05_side_effect_noloop(pkg):
    """Check a scatter-add side effect inside a loop with loop recording off.

    Two 10-entry integer arrays and a 10-entry float buffer start at zero.
    ``JitFlag.LoopRecord`` is set to ``False`` before the loop is built. Each
    iteration adds the counter to a running total, increments the counter,
    and scatter-adds the incremented counter into entry 0 of the buffer.

    Afterwards every counter entry is expected to be 10, every total entry 45,
    and the buffer to hold 550 in entry 0 with zeros elsewhere.

    ``pkg`` is handed to ``get_class`` to obtain the array and loop types.
    """
    types = get_class(pkg)

    counter = ek.zero(types.Int, 10)
    total = ek.zero(types.Int, 10)
    buf = ek.zero(types.Float, 10)
    ek.set_flag(ek.JitFlag.LoopRecord, False)

    loop = types.Loop("MyLoop", lambda: (counter, total))
    while loop(counter < 10):
        total += counter
        counter += 1
        # All entries write to the same slot (index 0) of the buffer.
        ek.scatter_reduce(ek.ReduceOp.Add, buf, types.Float(counter), 0)

    expected_counter = [10] * 10
    assert counter == types.Int(expected_counter)

    trailing_zeros = [0] * 9
    assert buf == types.Float(550, *trailing_zeros)

    expected_total = [45] * 10
    assert total == types.Int(expected_total)
예제 #5
0
파일: generic.py 프로젝트: wjakob/enoki
def scatter_reduce_(self, op, target, index, mask):
    """Apply ``_ek.scatter_reduce`` entry by entry over ``self``.

    ``target`` must have ``Depth == 1``. The number of entries processed is
    the largest of ``len(self)``, ``len(index)`` and ``len(mask)``; for each
    position ``k`` in that range, ``self[k]`` is scatter-reduced into
    ``target`` at ``index[k]`` under ``mask[k]`` using reduction ``op``.

    NOTE(review): operands shorter than the longest one are still indexed up
    to the longest length -- presumably their ``__getitem__`` broadcasts;
    confirm against the array implementation.
    """
    assert target.Depth == 1

    operands = (self, index, mask)
    entry_count = max(len(operand) for operand in operands)

    for k in range(entry_count):
        _ek.scatter_reduce(op, target, self[k], index[k], mask[k])