Esempio n. 1
0
def test05_scalar(t):
    """Check scalar construction and basic operations for array type ``t``.

    Types that are not Enoki arrays, or whose ``ek.array_size_v`` is 0, are
    skipped. Mask types are checked with the nested reductions
    ``all_nested`` / ``any_nested`` / ``none_nested``. Arithmetic types are
    checked for basic arithmetic, ``min``/``max``, signed-only and
    integral/non-integral specific operators, comparisons, ``select``,
    ``hsum``/``dot``/``dot_async`` and masked assignment.
    """
    if not (ek.is_array_v(t) and ek.array_size_v(t) != 0):
        return

    # NOTE(review): the result is discarded, so this is presumably called
    # only for a side effect of get_class -- confirm.
    get_class(t.__module__)

    if ek.is_mask_v(t):
        assert ek.all_nested(t(True))
        assert ek.any_nested(t(True))
        assert ek.none_nested(t(False))
        assert ek.all_nested(t(False) ^ t(True))
        assert ek.all_nested(ek.eq(t(False), t(False)))
        assert ek.none_nested(ek.eq(t(True), t(False)))

    # Everything below only applies to arithmetic types.
    if not ek.is_arithmetic_v(t):
        return

    assert (t(1) + t(1)) == t(2)
    assert (t(3) - t(1)) == t(2)
    assert (t(2) * t(2)) == t(4)
    assert ek.min(t(2), t(3)) == t(2)
    assert ek.max(t(2), t(3)) == t(3)

    if ek.is_signed_v(t):
        assert (t(2) * t(-2)) == t(-4)
        assert ek.abs(t(-2)) == t(2)

    if not ek.is_integral_v(t):
        # Non-integral (presumably floating-point) types. The expected values
        # pin fmadd(a, b, c) = a*b + c, fmsub = a*b - c, fnmadd = -a*b + c
        # and fnmsub = -a*b - c; the last three checks combine a value with a
        # Python bool through & and |.
        assert (t(6) / t(2)) == t(3)
        assert ek.sqrt(t(4)) == t(2)
        assert ek.fmadd(t(1), t(2), t(3)) == t(5)
        assert ek.fmsub(t(1), t(2), t(3)) == t(-1)
        assert ek.fnmadd(t(1), t(2), t(3)) == t(1)
        assert ek.fnmsub(t(1), t(2), t(3)) == t(-5)
        assert (t(1) & True) == t(1)
        assert (t(1) & False) == t(0)
        assert (t(1) | False) == t(1)
    else:
        # Integral types: floor division, modulo, shifts and bitwise ops.
        # In Python these operators bind tighter than ==, so the explicit
        # parentheses only make the evaluation order visible.
        assert (t(6) // t(2)) == t(3)
        assert (t(7) % t(2)) == t(1)
        assert (t(7) >> 1) == t(3)
        assert (t(7) << 1) == t(14)
        assert (t(1) | t(2)) == t(3)
        assert (t(1) ^ t(3)) == t(2)
        assert (t(1) & t(3)) == t(1)

    # Comparisons yield masks, which are reduced with all_nested.
    assert ek.all_nested(t(3) > t(2))
    assert ek.all_nested(ek.eq(t(2), t(2)))
    assert ek.all_nested(ek.neq(t(3), t(2)))
    assert ek.all_nested(t(1) >= t(1))
    assert ek.all_nested(t(2) < t(3))
    assert ek.all_nested(t(1) <= t(1))
    assert ek.select(ek.eq(t(2), t(2)), t(4), t(5)) == t(4)
    assert ek.select(ek.eq(t(3), t(2)), t(4), t(5)) == t(5)

    # Horizontal reductions over an array filled with 2.
    twos = t(2)
    count = len(twos)
    assert ek.hsum(twos) == t.Value(2 * count)
    assert ek.dot(twos, twos) == t.Value(4 * count)
    assert ek.dot_async(twos, twos) == t(4 * count)

    # Masked assignment: the first mask matches (value is 1), so the value
    # becomes 2; the second mask (== 3) then matches nothing.
    target = t(1)
    target[ek.eq(target, t(1))] = t(2)
    target[ek.eq(target, t(3))] = t(5)
    assert target == t(2)
Esempio n. 2
0
def test51_scatter_reduce_fwd_eager(m):
    """Forward-mode AD through ``scatter_reduce(ReduceOp.Add, ...)`` in eager mode.

    Three passes toggle which inputs carry gradients:
    pass 0 -> buf, x and y; pass 1 -> x and y only; pass 2 -> buf only.
    Every enabled input is seeded with a gradient of 1. The value and the
    forward gradient of ``s = dot_async(buf2, buf2)`` are then compared with
    reference numbers.
    """
    with EagerMode():
        for trial in range(3):
            track_buf = trial % 2 == 0
            track_xy = trial // 2 == 0

            first_idx = ek.arange(m.UInt, 5)
            second_idx = ek.arange(m.UInt, 4) + 3

            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if track_buf:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if track_xy:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            for var, name in ((x, "x"), (y, "y"), (buf, "buf")):
                var.label = name

            # Add x into slots 0..4 and y into slots 3..6 of buf2, which is
            # constructed from buf.
            buf2 = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, x, first_idx)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, y, second_idx)

            s = ek.dot_async(buf2, buf2)

            # Reference numbers were verified against Mathematica. 25.1667 is
            # the gradient contributed by the x/y seeds and 17 the one
            # contributed by the buf seed.
            expected_grad = 0
            if track_xy:
                expected_grad += 25.1667
            if track_buf:
                expected_grad += 17

            assert ek.allclose(ek.detach(s), 15.5972)
            assert ek.allclose(ek.grad(s), expected_grad)
Esempio n. 3
0
def test20_scatter_add_rev(m):
    """Reverse-mode AD through two overlapping ``scatter_add`` calls.

    Three passes toggle which inputs have gradients enabled:
    pass 0 -> buf, x and y; pass 1 -> x and y only; pass 2 -> buf only.
    After back-propagating ``s = dot_async(buf2, buf2)``, inputs with
    gradients enabled must match reference gradients. Inputs without them
    must have a zero gradient.
    """
    for trial in range(3):
        track_buf = trial % 2 == 0
        track_xy = trial // 2 == 0

        first_idx = ek.arange(m.UInt, 5)
        second_idx = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if track_buf:
            ek.enable_grad(buf)
        if track_xy:
            ek.enable_grad(x, y)

        for var, name in ((x, "x"), (y, "y"), (buf, "buf")):
            var.label = name

        # Add x into slots 0..4 and y into slots 3..6 of buf2, which is
        # constructed from buf.
        buf2 = m.Float(buf)
        ek.scatter_add(buf2, x, first_idx)
        ek.scatter_add(buf2, y, second_idx)

        ref_buf = m.Float(0.0000, 0.2500, 0.5000, 1.7500, 2.3333, 1.6667,
                          2.0000, 0.0000, 0.0000, 0.0000)

        # NOTE(review): checking buf as well as buf2 implies that buf observes
        # the scatters performed on buf2 -- confirm this aliasing is intended.
        for candidate in (buf2, buf):
            assert ek.allclose(ref_buf, candidate, atol=1e-4)

        s = ek.dot_async(buf2, buf2)

        # Debug output; presumably a Graphviz dump of the graph behind s.
        print(ek.graphviz_str(s))

        ek.backward(s)

        ref_x = m.Float(0.0000, 0.5000, 1.0000, 3.5000, 4.6667)
        ref_y = m.Float(3.5000, 4.6667, 3.3333, 4.0000)

        if not track_xy:
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0
        else:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)

        if not track_buf:
            assert ek.grad(buf) == 0
        else:
            # The expected gradient w.r.t. buf is twice the reference buffer.
            assert ek.allclose(ek.grad(buf), ek.detach(ref_buf) * 2, atol=1e-4)
Esempio n. 4
0
def test21_scatter_add_fwd(m):
    """Forward-mode AD through two overlapping ``scatter_add`` calls.

    Three passes toggle which inputs carry gradients:
    pass 0 -> buf, x and y; pass 1 -> x and y only; pass 2 -> buf only.
    Every enabled input is seeded with a gradient of 1, enqueued, and
    propagated with ``ek.traverse(..., reverse=False)``. The value and the
    gradient of ``s = dot_async(buf2, buf2)`` are then compared with
    reference numbers.
    """
    for trial in range(3):
        track_buf = trial % 2 == 0
        track_xy = trial // 2 == 0

        first_idx = ek.arange(m.UInt, 5)
        second_idx = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if track_buf:
            ek.enable_grad(buf)
            ek.set_grad(buf, 1)
        if track_xy:
            ek.enable_grad(x, y)
            ek.set_grad(x, 1)
            ek.set_grad(y, 1)

        for var, name in ((x, "x"), (y, "y"), (buf, "buf")):
            var.label = name

        # Add x into slots 0..4 and y into slots 3..6 of buf2, which is
        # constructed from buf.
        buf2 = m.Float(buf)
        ek.scatter_add(buf2, x, first_idx)
        ek.scatter_add(buf2, y, second_idx)

        s = ek.dot_async(buf2, buf2)

        # Enqueue the seeded inputs, then propagate gradients forward.
        if track_buf:
            ek.enqueue(buf)
        if track_xy:
            ek.enqueue(x, y)
        ek.traverse(m.Float, reverse=False)

        # Reference numbers were verified against Mathematica. 25.1667 is the
        # gradient contributed by the x/y seeds and 17 the one contributed
        # by the buf seed.
        expected_grad = 0
        if track_xy:
            expected_grad += 25.1667
        if track_buf:
            expected_grad += 17

        assert ek.allclose(ek.detach(s), 15.5972)
        assert ek.allclose(ek.grad(s), expected_grad)