def test09_hsum_2_rev(m):
    """Reverse-mode gradient through a horizontal sum nested inside another.

    The inner ``hsum_async(x * x)`` is broadcast against ``x * x`` and summed
    again, so the result is ``(sum x_i^2)^2``. Its gradient is
    ``4 * sum(x^2) * x``; on ``linspace(0, 1, 11)`` we have
    ``sum(x^2) == 3.85``, which makes the expected gradient ``15.4 * x``.
    """
    pts = ek.linspace(m.Float, 0, 1, 11)
    ek.enable_grad(pts)

    sq_total = ek.hsum_async(pts * pts)
    loss = ek.hsum_async(sq_total * pts * pts)
    ek.backward(loss)

    expected = [0., 1.54, 3.08, 4.62, 6.16, 7.7,
                9.24, 10.78, 12.32, 13.86, 15.4]
    assert ek.allclose(ek.grad(pts), expected)
def test06_hsum_0_fwd(m):
    """Forward-mode derivative of a single horizontal sum of squares.

    For ``x = linspace(0, 1, 10)`` (entries ``i / 9``) the value is
    ``sum(x^2) == 285 / 81 == 95 / 27``. The asserted tangent, ``10``, equals
    ``2 * sum(x)``, i.e. the derivative when every entry of ``x`` carries a
    unit tangent. Both the value and its tangent must be single-element arrays.
    """
    pts = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(pts)

    total = ek.hsum_async(pts * pts)
    ek.forward(pts)

    assert len(total) == 1
    assert ek.allclose(ek.detach(total), 95.0 / 27.0)
    assert len(ek.grad(total)) == 1
    assert ek.allclose(ek.grad(total), 10)
def test05_hsum_0_rev(m):
    """Reverse-mode gradient of a single horizontal sum of squares.

    ``sum(x^2)`` over ``linspace(0, 1, 10)`` is ``95 / 27`` and comes back as
    a one-element array. Its gradient with respect to ``x`` is ``2 * x``.
    """
    pts = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(pts)

    total = ek.hsum_async(pts * pts)
    ek.backward(total)

    assert len(total) == 1
    assert ek.allclose(total, 95.0 / 27.0)
    assert ek.allclose(ek.grad(pts), 2 * ek.detach(pts))
def test18_gather(m):
    """Reverse-mode gradient through a gather with a repeated index.

    ``x * x`` is gathered at indices ``[1, 1, 2, 3]`` and summed. Each gathered
    element contributes ``2 * x_i`` to the gradient, so index 1 (read twice)
    receives ``4 * x_1``. Entries that are never gathered get zero gradient.
    With ``x = linspace(-1, 1, 10)``:
    ``x_1 = -7/9``, ``x_2 = -5/9``, ``x_3 = -1/3``.
    """
    pts = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(pts)

    picks = m.UInt(1, 1, 2, 3)
    picked = ek.gather(m.Float, pts * pts, picks)
    loss = ek.hsum_async(picked)
    ek.backward(loss)

    expected = [0, -1.55556 * 2, -1.11111, -0.666667,
                0, 0, 0, 0, 0, 0]
    assert ek.allclose(ek.grad(pts), expected)
def test07_hsum_1_rev(m):
    """Reverse-mode gradient of ``hsum(hsum(x) * x)``.

    The result equals ``S^2`` with ``S = sum(x)``, so every entry of the
    gradient is ``2 * S``. On ``linspace(0, 1, 11)``, ``S == 5.5``, giving
    ``11``.
    """
    pts = ek.linspace(m.Float, 0, 1, 11)
    ek.enable_grad(pts)

    lin_total = ek.hsum_async(pts)
    loss = ek.hsum_async(lin_total * pts)
    ek.backward(loss)

    assert ek.allclose(ek.grad(pts), 11)
def test10_hsum_2_fwd(m):
    """Forward-mode derivative of a product of two independent horizontal sums.

    Each factor is ``S = sum(x^2) = 95 / 27`` over ``linspace(0, 1, 10)``,
    with tangent ``2 * sum(x) = 10``. The product rule then gives
    ``2 * S * 10 = 1900 / 27``. The two factors are built as separate
    ``hsum_async`` calls so the graph contains two distinct branches.
    """
    pts = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(pts)

    left = ek.hsum_async(pts * pts)
    right = ek.hsum_async(pts * pts)
    product = ek.hsum_async(left * right)
    ek.forward(pts)

    assert ek.allclose(ek.grad(product), 1900.0 / 27.0)
def test08_hsum_1_fwd(m):
    """Forward-mode derivative of ``hsum(hsum(x) * x)``.

    The result is ``S^2`` with ``S = sum(x) = 5`` on ``linspace(0, 1, 10)``.
    The tangent of ``S`` is ``10``, so the expected tangent of the result is
    ``2 * 5 * 10 = 100``.
    """
    pts = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(pts)

    lin_total = ek.hsum_async(pts)
    result = ek.hsum_async(lin_total * pts)
    ek.forward(pts)

    assert ek.allclose(ek.grad(result), 100)