예제 #1
0
    def testPromoteDtypes(self):
        """Checks dtypes.promote_types against algebraic and NumPy rules.

        Covers: idempotence, bool as an identity operand, complex128 as an
        absorbing operand, symmetry, the float16/bfloat16 special case,
        inexact-beats-exact promotion, and agreement with np.promote_types
        inside the exact group and inside the (np_float + complex) group.
        """
        for lhs in all_dtypes:
            # Promoting a dtype with itself or with bool leaves it unchanged.
            self.assertEqual(lhs, dtypes.promote_types(lhs, lhs))
            self.assertEqual(lhs, dtypes.promote_types(lhs, np.bool_))
            # Anything promoted with complex128 yields complex128.
            self.assertEqual(np.dtype(np.complex128),
                             dtypes.promote_types(lhs, np.complex128))

            # Argument order must not matter.
            for rhs in all_dtypes:
                forward = dtypes.promote_types(lhs, rhs)
                backward = dtypes.promote_types(rhs, lhs)
                self.assertEqual(forward, backward)

        # Mixing float16 with bfloat16 widens to float32.
        self.assertEqual(np.dtype(np.float32),
                         dtypes.promote_types(np.float16, dtypes.bfloat16))

        # A float/complex operand always wins over a bool/integer operand.
        exact_dtypes = bool_dtypes + signed_dtypes + unsigned_dtypes
        for inexact in float_dtypes + complex_dtypes:
            for exact in exact_dtypes:
                self.assertEqual(inexact, dtypes.promote_types(inexact, exact))

        # Within each group, the result must match NumPy's own promotion.
        for group in (exact_dtypes, np_float_dtypes + complex_dtypes):
            for lhs, rhs in itertools.combinations(group, 2):
                self.assertEqual(np.promote_types(lhs, rhs),
                                 dtypes.promote_types(lhs, rhs))
예제 #2
0
 def testMixedIntBool(self):
   """Ravel bool + int32 leaves and check dtype and exact round trip.

   The flat vector must carry the promoted dtype of the two leaf dtypes, and
   unraveling it must reproduce the original leaves with zero tolerance.
   """
   leaves = [
       jnp.array([0], jnp.bool_),
       jnp.array([[1, 2], [3, 4]], jnp.int32),
   ]
   flat, unflatten = flatten_util.ravel_pytree(leaves)
   expected_dtype = dtypes.promote_types(jnp.bool_, jnp.int32)
   self.assertEqual(flat.dtype, expected_dtype)
   restored = unflatten(flat)
   self.assertAllClose(leaves, restored, atol=0., rtol=0.)
예제 #3
0
 def testMixedFloatComplex(self):
   """Ravel float32 + complex64 leaves and check dtype and exact round trip.

   The flat vector must carry the promoted dtype of the two leaf dtypes, and
   unraveling it must reproduce the original leaves with zero tolerance.
   """
   leaves = [
       jnp.array([1.], jnp.float32),
       jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64),
   ]
   flat, unflatten = flatten_util.ravel_pytree(leaves)
   expected_dtype = dtypes.promote_types(jnp.float32, jnp.complex64)
   self.assertEqual(flat.dtype, expected_dtype)
   restored = unflatten(flat)
   self.assertAllClose(leaves, restored, atol=0., rtol=0.)
예제 #4
0
 def testMixedFloatInt(self):
   """Ravel int32 + float32 leaves and check dtype and exact round trip.

   The flat vector must carry the promoted dtype of the two leaf dtypes, and
   unraveling it must reproduce the original leaves with zero tolerance.
   """
   leaves = [
       jnp.array([3], jnp.int32),
       jnp.array([[1., 2.], [3., 4.]], jnp.float32),
   ]
   flat, unflatten = flatten_util.ravel_pytree(leaves)
   expected_dtype = dtypes.promote_types(jnp.float32, jnp.int32)
   self.assertEqual(flat.dtype, expected_dtype)
   restored = unflatten(flat)
   self.assertAllClose(leaves, restored, atol=0., rtol=0.)