def testRepeatedDoubling(self):
  """Each extra level of doubling recovers more of a cancelling float32 sum.

  The expression a + b + c - a - b is mathematically equal to c. With
  a = 1e20, b = 1, c = 1e-20 in float32, the test expects:

  * plain float32 evaluation to return -b,
  * one application of `doubledouble` to return 0,
  * two nested applications to return the exact answer c.
  """
  def cancelling_sum(a, b, c):
    return a + b + c - a - b

  doubled = doubledouble(cancelling_sum)
  doubled_twice = doubledouble(doubled)

  big, one, tiny = (jnp.float32(value) for value in (1E20, 1.0, 1E-20))

  # Checked from least to most precise evaluation.
  self.assertEqual(cancelling_sum(big, one, tiny), -one)
  self.assertEqual(doubled(big, one, tiny), 0)
  self.assertEqual(doubled_twice(big, one, tiny), tiny)
def testDoubledPrecision(self, shape, dtype, op1, op2):
  """Doubled op1 must match op2 on inputs where plain op1 loses precision.

  `shape`, `dtype`, `op1` and `op2` are presumably supplied by a
  parameterized-test decorator that is not visible here. The first operand
  is scaled by 1e20 so that ordinary arithmetic loses precision.
  """
  rng = jtu.rand_default(self.rng())
  doubled_op1 = doubledouble(op1)

  # Draw order matters for the RNG stream: scaled operand first.
  scaled_operand = 1E20 * rng(shape, dtype)
  plain_operand = rng(shape, dtype)
  operands = (scaled_operand, plain_operand)

  # Output dtypes are compared only when x64 mode is disabled.
  compare_dtypes = not FLAGS.jax_enable_x64

  reference = op2(*operands)
  self.assertAllClose(doubled_op1(*operands), reference,
                      check_dtypes=compare_dtypes)

  # Sanity check: the undoubled op1 is expected to miss the reference, which
  # proves the inputs really exercise the precision problem.
  with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
    self.assertAllClose(op1(*operands), op2(*operands),
                        check_dtypes=compare_dtypes)
def testBinaryOp(self, dtype, shape, op):
  """A doubled binary op must agree with the plain op on default random inputs.

  `dtype`, `shape` and `op` are presumably supplied by a parameterized-test
  decorator that is not visible here.
  """
  rng = jtu.rand_default(self.rng())
  doubled = doubledouble(op)

  # Left operand is drawn before the right one, matching the RNG stream order.
  lhs = rng(shape, dtype)
  rhs = rng(shape, dtype)

  expected = op(lhs, rhs)
  actual = doubled(lhs, rhs)
  self.assertAllClose(expected, actual)
def testTypeConversion(self):
  """Doubling keeps precision through a float16 add followed by a float32 cast.

  The test expects plain float16 evaluation of 1e2*x + 1e-2*x to drop the
  small term, giving roughly 1e2 * x. It expects the doubled version to keep
  the small term, giving roughly 100.01 * x.
  """
  base = jnp.arange(10, dtype='float16')

  def add_then_upcast(a, b):
    return (a + b).astype('float32')

  doubled = doubledouble(add_then_upcast)
  base_f32 = base.astype('float32')

  self.assertAllClose(add_then_upcast(1E2 * base, 1E-2 * base),
                      1E2 * base_f32)
  self.assertAllClose(doubled(1E2 * base, 1E-2 * base),
                      100.01 * base_f32)