Example #1
0
def test09_hsum_2_rev(m):
    """Reverse-mode AD through a nested horizontal sum.

    With x = linspace(0, 1, 11), i.e. 0, 0.1, ..., 1.0, the graph computes
    z = hsum(hsum(x*x) * x * x) = (sum x_i^2)^2, hence
    dz/dx_i = 4 * (sum x^2) * x_i = 4 * 3.85 * x_i = 15.4 * x_i.

    m: module supplying the ``Float`` array type (presumably a pytest
    fixture selecting the backend variant -- confirm against conftest).
    """
    x = ek.linspace(m.Float, 0, 1, 11)
    ek.enable_grad(x)

    # The inner sum also depends on x, so gradients must flow back through
    # both reductions.
    sum_sq = ek.hsum_async(x * x)
    z = ek.hsum_async(sum_sq * x * x)
    ek.backward(z)

    expected = [0., 1.54, 3.08, 4.62, 6.16, 7.7,
                9.24, 10.78, 12.32, 13.86, 15.4]
    assert ek.allclose(ek.grad(x), expected)
Example #2
0
def test06_hsum_0_fwd(m):
    """Forward-mode AD through a single horizontal sum.

    x = linspace(0, 1, 10) gives x_i = i/9 for i = 0..9, so
    y = sum x_i^2 = 285/81 = 95/27. Propagating a unit tangent from x
    gives dy = sum 2*x_i = 2 * 45/9 = 10.
    """
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    y = ek.hsum_async(x * x)
    ek.forward(x)

    # Both the primal result and its tangent are single-entry arrays.
    assert len(y) == 1
    assert ek.allclose(ek.detach(y), 95.0 / 27.0)
    grad_y = ek.grad(y)
    assert len(grad_y) == 1
    assert ek.allclose(grad_y, 10)
Example #3
0
def test05_hsum_0_rev(m):
    """Reverse-mode AD through a single horizontal sum.

    y = sum x_i^2 over x = linspace(0, 1, 10), so y = 95/27 and
    dy/dx_i = 2 * x_i.
    """
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)
    y = ek.hsum_async(x * x)
    ek.backward(y)

    assert len(y) == 1
    assert ek.allclose(y, 95.0 / 27.0)
    # Reference gradient d(x^2)/dx = 2x, built from the primal values of x.
    expected_grad = 2 * ek.detach(x)
    assert ek.allclose(ek.grad(x), expected_grad)
Example #4
0
def test18_gather(m):
    """Reverse-mode AD through a gather that repeats an index.

    x = linspace(-1, 1, 10), so x_i = -1 + 2*i/9. Gathering x*x at indices
    [1, 1, 2, 3] and summing gives z = 2*x_1^2 + x_2^2 + x_3^2. The gradient
    is therefore 4*x_1 at index 1 (the index is gathered twice), 2*x_2 and
    2*x_3 at indices 2 and 3, and 0 at every index that was never gathered.
    """
    x = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(x)
    y = ek.gather(m.Float, x * x, m.UInt(1, 1, 2, 3))
    z = ek.hsum_async(y)
    ek.backward(z)

    # Exact reference values (x_1 = -7/9, x_2 = -5/9, x_3 = -3/9) instead of
    # hand-rounded 6-digit decimals. This keeps the derivation visible and
    # leaves the rounding error out of the comparison tolerance.
    ref = [0, 4 * (-7 / 9), 2 * (-5 / 9), 2 * (-3 / 9), 0, 0, 0, 0, 0, 0]
    assert ek.allclose(ek.grad(x), ref)
Example #5
0
def test07_hsum_1_rev(m):
    """Reverse-mode AD where a reduction of x scales x itself.

    With s = sum(x) over x = linspace(0, 1, 11), s = 5.5 and
    y = sum(s * x) = s^2, so every component of dy/dx equals 2*s = 11.
    """
    x = ek.linspace(m.Float, 0, 1, 11)
    ek.enable_grad(x)

    total = ek.hsum_async(x)
    y = ek.hsum_async(total * x)
    ek.backward(y)

    assert ek.allclose(ek.grad(x), 11)
Example #6
0
def test10_hsum_2_fwd(m):
    """Forward-mode AD through the product of two sum-of-squares reductions.

    Over x = linspace(0, 1, 10), s = sum x_i^2 = 95/27 with tangent
    ds = sum 2*x_i = 10. Then y = hsum(s * s) = s^2, so
    dy = 2 * s * ds = 1900/27.
    """
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)

    # Each factor is its own hsum_async call on x*x.
    lhs = ek.hsum_async(x * x)
    rhs = ek.hsum_async(x * x)
    y = ek.hsum_async(lhs * rhs)
    ek.forward(x)

    assert ek.allclose(ek.grad(y), 1900.0 / 27.0)
Example #7
0
def test08_hsum_1_fwd(m):
    """Forward-mode AD where a reduction of x scales x itself.

    Over x = linspace(0, 1, 10), s = sum(x) = 5 with tangent ds = 10.
    y = sum(s * x) = s^2, so dy = 2 * s * ds = 100.
    """
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)

    total = ek.hsum_async(x)
    y = ek.hsum_async(total * x)
    ek.forward(x)

    assert ek.allclose(ek.grad(y), 100)