def test05_scalar(t):
    """Check scalar construction and elementwise operations on array type ``t``.

    The test returns early unless ``t`` is an Enoki array whose
    ``array_size_v`` is non-zero. Mask types are checked against the
    horizontal reductions ``all_nested`` / ``any_nested`` / ``none_nested``.
    Arithmetic types are checked against the basic operators, min/max,
    comparisons, ``select``, horizontal sum / dot products and masked
    assignment; signed, integral and non-integral types each get extra
    operator checks.
    """
    if not ek.is_array_v(t) or ek.array_size_v(t) == 0:
        return
    get_class(t.__module__)

    if ek.is_mask_v(t):
        assert ek.all_nested(t(True))
        assert ek.any_nested(t(True))
        assert ek.none_nested(t(False))
        assert ek.all_nested(t(False) ^ t(True))
        assert ek.all_nested(ek.eq(t(False), t(False)))
        assert ek.none_nested(ek.eq(t(True), t(False)))

    # Every remaining check requires an arithmetic type.
    if not ek.is_arithmetic_v(t):
        return

    assert t(1) + t(1) == t(2)
    assert t(3) - t(1) == t(2)
    assert t(2) * t(2) == t(4)
    assert ek.min(t(2), t(3)) == t(2)
    assert ek.max(t(2), t(3)) == t(3)

    if ek.is_signed_v(t):
        assert t(2) * t(-2) == t(-4)
        assert ek.abs(t(-2)) == t(2)

    if not ek.is_integral_v(t):
        # Non-integral arithmetic: true division, sqrt, the fused
        # multiply-add family, and combining values with Python bools.
        assert t(6) / t(2) == t(3)
        assert ek.sqrt(t(4)) == t(2)
        assert ek.fmadd(t(1), t(2), t(3)) == t(5)    # 1*2 + 3
        assert ek.fmsub(t(1), t(2), t(3)) == t(-1)   # 1*2 - 3
        assert ek.fnmadd(t(1), t(2), t(3)) == t(1)   # -(1*2) + 3
        assert ek.fnmsub(t(1), t(2), t(3)) == t(-5)  # -(1*2) - 3
        assert (t(1) & True) == t(1)
        assert (t(1) & False) == t(0)
        assert (t(1) | False) == t(1)
    else:
        # Integral arithmetic: floor division, modulo, shifts, bitwise logic.
        assert t(6) // t(2) == t(3)
        assert t(7) % t(2) == t(1)
        assert t(7) >> 1 == t(3)
        assert t(7) << 1 == t(14)
        assert t(1) | t(2) == t(3)
        assert t(1) ^ t(3) == t(2)
        assert t(1) & t(3) == t(1)

    # Comparisons produce masks; reduce them to a single Python bool.
    assert ek.all_nested(t(3) > t(2))
    assert ek.all_nested(ek.eq(t(2), t(2)))
    assert ek.all_nested(ek.neq(t(3), t(2)))
    assert ek.all_nested(t(1) >= t(1))
    assert ek.all_nested(t(2) < t(3))
    assert ek.all_nested(t(1) <= t(1))
    assert ek.select(ek.eq(t(2), t(2)), t(4), t(5)) == t(4)
    assert ek.select(ek.eq(t(3), t(2)), t(4), t(5)) == t(5)

    # Horizontal reductions over a constant-2 array scale with its length.
    twos = t(2)
    count = len(twos)
    assert ek.hsum(twos) == t.Value(2 * count)
    assert ek.dot(twos, twos) == t.Value(4 * count)
    assert ek.dot_async(twos, twos) == t(4 * count)

    # Masked assignment: the first mask matches (entries equal 1), so they
    # become 2; the second mask (== 3) must leave the array untouched.
    masked = t(1)
    masked[ek.eq(masked, t(1))] = t(2)
    masked[ek.eq(masked, t(3))] = t(5)
    assert masked == t(2)
def test51_scatter_reduce_fwd_eager(m):
    """Forward-mode gradients through two additive ``scatter_reduce`` calls, eager mode.

    Two ramps ``x`` (5 entries) and ``y`` (4 entries) are scatter-added into
    overlapping slots of a copy of a zero buffer ``buf`` of length 10. The
    squared norm ``s`` of the result is then checked, along with its forward
    gradient. Three seeding configurations are run:

    * ``i == 0``: gradient 1 seeded on ``buf``, ``x`` and ``y``
    * ``i == 1``: gradient 1 seeded on ``x`` and ``y`` only
    * ``i == 2``: gradient 1 seeded on ``buf`` only

    The whole loop runs inside ``EagerMode()``, and no explicit
    enqueue/traverse step is performed.
    """
    with EagerMode():
        for i in range(3):
            seed_buf = i % 2 == 0   # True for i = 0, 2
            seed_xy = i // 2 == 0   # True for i = 0, 1

            idx_x = ek.arange(m.UInt, 5)
            idx_y = ek.arange(m.UInt, 4) + 3
            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if seed_buf:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if seed_xy:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label = "x"
            y.label = "y"
            buf.label = "buf"

            target = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, target, x, idx_x)
            ek.scatter_reduce(ek.ReduceOp.Add, target, y, idx_y)
            s = ek.dot_async(target, target)

            # Verified against Mathematica; 25.1667 is the contribution of
            # the x/y seeds and 17 that of the buf seed.
            expected_grad = (25.1667 if seed_xy else 0) + (17 if seed_buf else 0)
            assert ek.allclose(ek.detach(s), 15.5972)
            assert ek.allclose(ek.grad(s), expected_grad)
def test20_scatter_add_rev(m):
    """Reverse-mode gradients through two ``scatter_add`` calls.

    Two ramps ``x`` (5 entries) and ``y`` (4 entries) are scatter-added into
    overlapping slots of a copy of a zero buffer ``buf`` of length 10. The
    squared norm ``s`` is back-propagated, and each input's gradient is
    compared with ``2 * buf2`` gathered at the slots that input wrote to.
    Three configurations are run:

    * ``i == 0``: gradients enabled on ``buf``, ``x`` and ``y``
    * ``i == 1``: gradients enabled on ``x`` and ``y`` only
    * ``i == 2``: gradients enabled on ``buf`` only

    Inputs without gradients enabled must report a zero gradient.
    """
    for i in range(3):
        track_buf = i % 2 == 0   # True for i = 0, 2
        track_xy = i // 2 == 0   # True for i = 0, 1

        idx_x = ek.arange(m.UInt, 5)
        idx_y = ek.arange(m.UInt, 4) + 3
        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if track_buf:
            ek.enable_grad(buf)
        if track_xy:
            ek.enable_grad(x, y)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        target = m.Float(buf)
        ek.scatter_add(target, x, idx_x)
        ek.scatter_add(target, y, idx_y)

        ref_buf = m.Float(0.0000, 0.2500, 0.5000, 1.7500, 2.3333,
                          1.6667, 2.0000, 0.0000, 0.0000, 0.0000)
        assert ek.allclose(ref_buf, target, atol=1e-4)
        # NOTE(review): this also expects the original ``buf`` to hold the
        # scattered values; confirm that this aliasing through
        # ``m.Float(buf)`` is intended.
        assert ek.allclose(ref_buf, buf, atol=1e-4)

        s = ek.dot_async(target, target)
        print(ek.graphviz_str(s))
        ek.backward(s)

        # d(s)/d(input) = 2 * target at the slots each input wrote to.
        ref_x = m.Float(0.0000, 0.5000, 1.0000, 3.5000, 4.6667)
        ref_y = m.Float(3.5000, 4.6667, 3.3333, 4.0000)

        if not track_xy:
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0
        else:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)

        if not track_buf:
            assert ek.grad(buf) == 0
        else:
            assert ek.allclose(ek.grad(buf), ek.detach(ref_buf) * 2, atol=1e-4)
def test21_scatter_add_fwd(m):
    """Forward-mode gradients through two ``scatter_add`` calls, propagated explicitly.

    Two ramps ``x`` (5 entries) and ``y`` (4 entries) are scatter-added into
    overlapping slots of a copy of a zero buffer ``buf`` of length 10. Its
    squared norm ``s`` is formed, and the seeded inputs are enqueued. A
    forward traversal (``reverse=False``) then propagates their gradients
    to ``s``. Three seeding configurations are run:

    * ``i == 0``: gradient 1 seeded on ``buf``, ``x`` and ``y``
    * ``i == 1``: gradient 1 seeded on ``x`` and ``y`` only
    * ``i == 2``: gradient 1 seeded on ``buf`` only
    """
    for i in range(3):
        seed_buf = i % 2 == 0   # True for i = 0, 2
        seed_xy = i // 2 == 0   # True for i = 0, 1

        idx_x = ek.arange(m.UInt, 5)
        idx_y = ek.arange(m.UInt, 4) + 3
        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if seed_buf:
            ek.enable_grad(buf)
            ek.set_grad(buf, 1)
        if seed_xy:
            ek.enable_grad(x, y)
            ek.set_grad(x, 1)
            ek.set_grad(y, 1)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        target = m.Float(buf)
        ek.scatter_add(target, x, idx_x)
        ek.scatter_add(target, y, idx_y)
        s = ek.dot_async(target, target)

        # Queue exactly the seeded inputs, then push gradients forward.
        if seed_buf:
            ek.enqueue(buf)
        if seed_xy:
            ek.enqueue(x, y)
        ek.traverse(m.Float, reverse=False)

        # Verified against Mathematica; 25.1667 is the contribution of the
        # x/y seeds and 17 that of the buf seed.
        expected_grad = (25.1667 if seed_xy else 0) + (17 if seed_buf else 0)
        assert ek.allclose(ek.detach(s), 15.5972)
        assert ek.allclose(ek.grad(s), expected_grad)