Esempio n. 1
0
  def testRepeatedDoubling(self):
    """Each extra level of doubledouble nesting keeps one more small term.

    The expression a + b + c - a - b is mathematically equal to c. With
    float32 inputs spanning 1e20 down to 1e-20, the test pins what each
    precision level actually returns.
    """
    def cancel(a, b, c):
      return a + b + c - a - b

    dtype = jnp.float32
    big, one, tiny = dtype(1E20), dtype(1.0), dtype(1E-20)

    doubled = doubledouble(cancel)
    quadrupled = doubledouble(doubled)

    # Native float32: the expected result is -one, i.e. both small terms
    # are absorbed by `big` before the cancellation happens.
    self.assertEqual(cancel(big, one, tiny), -one)
    # One level of doubling: expected 0, so `one` survives but `tiny` does not.
    self.assertEqual(doubled(big, one, tiny), 0)
    # Two levels of doubling: the exact answer `tiny` is recovered.
    self.assertEqual(quadrupled(big, one, tiny), tiny)
Esempio n. 2
0
  def testDoubledPrecision(self, shape, dtype, op1, op2):
    """Test operations that would lose precision without doubling.

    `op1` wrapped in doubledouble must match the reference `op2`. The same
    comparison against an unwrapped `op1` is expected to fail, which shows
    that the inputs really do need the extra precision.
    """
    # dtypes are only compared strictly when 64-bit mode is disabled.
    check_dtypes = not FLAGS.jax_enable_x64
    sample = jtu.rand_default(self.rng())
    # The first operand is scaled by 1e20 so that mixing it with the
    # unscaled second operand is presumably beyond `dtype`'s native
    # precision (the sanity check below relies on this).
    scaled = 1E20 * sample(shape, dtype)
    unscaled = sample(shape, dtype)
    operands = (scaled, unscaled)

    doubled = doubledouble(op1)
    self.assertAllClose(doubled(*operands), op2(*operands),
                        check_dtypes=check_dtypes)

    # Without doubling, op1 is expected to diverge from the reference.
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      self.assertAllClose(op1(*operands), op2(*operands),
                          check_dtypes=check_dtypes)
Esempio n. 3
0
 def testBinaryOp(self, dtype, shape, op):
     """doubledouble(op) must agree with plain op on ordinary random inputs."""
     sample = jtu.rand_default(self.rng())
     lhs = sample(shape, dtype)
     rhs = sample(shape, dtype)
     reference = op(lhs, rhs)
     doubled_result = doubledouble(op)(lhs, rhs)
     self.assertAllClose(reference, doubled_result)
Esempio n. 4
0
 def testTypeConversion(self):
     """A float16 sum cast to float32 keeps its small term only when doubled.

     The expected values below pin native float16 to ~100*x and the
     doubledouble version to ~100.01*x after the float32 cast.
     """
     def add_then_upcast(a, b):
         return (a + b).astype('float32')

     base = jnp.arange(10, dtype='float16')
     large, small = 1E2 * base, 1E-2 * base
     base_f32 = base.astype('float32')

     doubled = doubledouble(add_then_upcast)

     # Native float16 addition: the expected result omits the 1e-2 term.
     self.assertAllClose(add_then_upcast(large, small), 1E2 * base_f32)
     # Doubled addition: the 1e-2 term is expected to survive into float32.
     self.assertAllClose(doubled(large, small), 100.01 * base_f32)