def testPromoteDtypes(self):
    """Check the algebraic properties of `dtypes.promote_types`.

    Properties exercised:
      * reflexivity: promoting a dtype with itself yields that dtype;
      * np.bool_ acts as an identity and np.complex128 absorbs every dtype
        in `all_dtypes`;
      * symmetry over every ordered pair drawn from `all_dtypes`;
      * float16 combined with bfloat16 promotes to float32;
      * any float/complex dtype wins against any bool/signed/unsigned dtype;
      * inside the exact group, and inside the NumPy-known inexact group,
        results agree with `np.promote_types`.
    """
    promote = dtypes.promote_types
    exact_dtypes = bool_dtypes + signed_dtypes + unsigned_dtypes

    for lhs in all_dtypes:
        self.assertEqual(lhs, promote(lhs, lhs))
        self.assertEqual(lhs, promote(lhs, np.bool_))
        self.assertEqual(np.dtype(np.complex128), promote(lhs, np.complex128))
        for rhs in all_dtypes:
            # Symmetry: argument order must not affect the result.
            self.assertEqual(promote(lhs, rhs), promote(rhs, lhs))

    # The two 16-bit float formats meet at float32 rather than either one.
    self.assertEqual(np.dtype(np.float32),
                     promote(np.float16, dtypes.bfloat16))

    # Promotions of non-inexact types against inexact types always prefer
    # the inexact types.
    for inexact in float_dtypes + complex_dtypes:
        for exact in exact_dtypes:
            self.assertEqual(inexact, promote(inexact, exact))

    # Within the exact group, and within the inexact group restricted to
    # np_float_dtypes (presumably the floats NumPy itself defines, i.e. no
    # bfloat16 -- confirm), promotion must match NumPy's own table.
    for group in (exact_dtypes, np_float_dtypes + complex_dtypes):
        for a, b in itertools.combinations(group, 2):
            self.assertEqual(np.promote_types(a, b), promote(a, b))
def testMixedIntBool(self):
    """Round-trip a pytree mixing a bool leaf and an int32 leaf via ravel_pytree.

    Verifies that the flattened result is a 1-D vector containing every leaf
    element, that its dtype is the promotion of the leaf dtypes, and that the
    returned `unravel` function reconstructs the original tree exactly.
    """
    tree = [jnp.array([0], jnp.bool_),
            jnp.array([[1, 2], [3, 4]], jnp.int32)]
    raveled, unravel = flatten_util.ravel_pytree(tree)
    # The flat vector must hold all leaf elements end to end (1 + 4 = 5 here);
    # checking this keeps a wrong-shaped ravel from hiding behind unravel.
    self.assertEqual(raveled.shape, (sum(leaf.size for leaf in tree),))
    self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.bool_, jnp.int32))
    tree_ = unravel(raveled)
    # Zero tolerances: the round trip must be exact, not merely close.
    self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testMixedFloatComplex(self):
    """Round-trip a pytree mixing a float32 leaf and a complex64 leaf via ravel_pytree.

    Verifies that the flattened result is a 1-D vector containing every leaf
    element, that its dtype is the promotion of the leaf dtypes, and that the
    returned `unravel` function reconstructs the original tree exactly.
    """
    tree = [jnp.array([1.], jnp.float32),
            jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
    raveled, unravel = flatten_util.ravel_pytree(tree)
    # The flat vector must hold all leaf elements end to end (1 + 4 = 5 here);
    # checking this keeps a wrong-shaped ravel from hiding behind unravel.
    self.assertEqual(raveled.shape, (sum(leaf.size for leaf in tree),))
    self.assertEqual(raveled.dtype,
                     dtypes.promote_types(jnp.float32, jnp.complex64))
    tree_ = unravel(raveled)
    # Zero tolerances: the round trip must be exact, not merely close.
    self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testMixedFloatInt(self):
    """Round-trip a pytree mixing an int32 leaf and a float32 leaf via ravel_pytree.

    Verifies that the flattened result is a 1-D vector containing every leaf
    element, that its dtype is the promotion of the leaf dtypes, and that the
    returned `unravel` function reconstructs the original tree exactly.
    """
    tree = [jnp.array([3], jnp.int32),
            jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
    raveled, unravel = flatten_util.ravel_pytree(tree)
    # The flat vector must hold all leaf elements end to end (1 + 4 = 5 here);
    # checking this keeps a wrong-shaped ravel from hiding behind unravel.
    self.assertEqual(raveled.shape, (sum(leaf.size for leaf in tree),))
    # Arguments listed in leaf order; promotion is symmetric, so this is the
    # same expected dtype as promote_types(float32, int32).
    self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.int32, jnp.float32))
    tree_ = unravel(raveled)
    # Zero tolerances: the round trip must be exact, not merely close.
    self.assertAllClose(tree, tree_, atol=0., rtol=0.)