def testAllowF64TruePreferF32True(self):
  """float64 allowed + float32 preferred: Python floats promote to float32.

  An explicit float64 operand must still win promotion against a Python
  float.
  """
  np_dtypes.set_allow_float64(True)
  np_dtypes.set_prefer_float32(True)
  preferred = dtypes.float32
  self.assertEqual(preferred, np_dtypes.default_float_type())
  self.assertEqual(preferred, np_dtypes._result_type(1.1))
  f64_scalar = np.zeros([], np.float64)
  self.assertEqual(dtypes.float64, np_dtypes._result_type(f64_scalar, 1.1))
def result_type(*arrays_and_dtypes):
  """Computes the dtype NumPy type promotion yields for the given arguments.

  Args:
    *arrays_and_dtypes: array_like objects and/or dtypes. Nested structures
      are flattened (via `nest.flatten`) before promotion.

  Returns:
    The promoted dtype, as computed by `np_dtypes._result_type`.
  """

  def _as_promotable(item):
    # np.ndarray is deliberately NOT converted here: np.result_type looks at
    # the value (not just the dtype) of an np.ndarray to decide the result.
    if isinstance(item,
                  (np_arrays.ndarray, ops.Tensor, indexed_slices.IndexedSlices)):
      return _to_numpy_type(item.dtype)
    if isinstance(item, dtypes.DType):
      return _to_numpy_type(item)
    return item

  promotable = list(map(_as_promotable, nest.flatten(arrays_and_dtypes)))
  # With no inputs at all, defer to whatever dtype NumPy picks for an empty
  # array.
  promotable = promotable or [np.asarray([])]
  return np_dtypes._result_type(*promotable)  # pylint: disable=protected-access
def _result_type_binary(t1, t2):
  """Two-argument fast path for `result_type`, used for performance.

  Skips flattening and list building for the binary case. If dtype
  extraction or promotion raises `ValueError`, falls back to the general
  `result_type(t1, t2)`.
  """
  try:
    # Both conversions stay inside the try so a ValueError from either one
    # also routes to the general path.
    lhs = _maybe_get_dtype(t1)
    rhs = _maybe_get_dtype(t2)
    return np_dtypes._result_type(lhs, rhs)  # pylint: disable=protected-access
  except ValueError:
    return result_type(t1, t2)
def result_type(*arrays_and_dtypes):
  """Computes the dtype NumPy type promotion yields for the given arguments.

  Args:
    *arrays_and_dtypes: array_like objects and/or dtypes. Nested structures
      are flattened (via `nest.flatten`) and each leaf is passed through
      `_maybe_get_dtype` before promotion.

  Returns:
    The promoted dtype, as computed by `np_dtypes._result_type`.
  """
  leaves = nest.flatten(arrays_and_dtypes)
  converted = list(map(_maybe_get_dtype, leaves))
  # With no inputs at all, defer to whatever dtype NumPy picks for an empty
  # array.
  promotion_args = converted or [np.asarray([])]
  return np_dtypes._result_type(*promotion_args)  # pylint: disable=protected-access
def result_type(*arrays_and_dtypes):
  """Computes the dtype NumPy type promotion yields for the given arguments.

  Args:
    *arrays_and_dtypes: array_like objects and/or dtypes. Nested structures
      are flattened (via `nest.flatten`) before promotion.

  Returns:
    The promoted dtype, as computed by `np_dtypes._result_type`.
  """

  def _as_promotable(item):
    # np.ndarray is deliberately NOT converted here: np.result_type looks at
    # the value (not just the dtype) of an np.ndarray to decide the result.
    if isinstance(
        item, (np_arrays.ndarray, core.Tensor, indexed_slices.IndexedSlices)):
      return _to_numpy_type(item.dtype)
    if isinstance(item, dtypes.DType):
      return _to_numpy_type(item)
    return item

  promotable = list(map(_as_promotable, nest.flatten(arrays_and_dtypes)))
  # With no inputs at all, defer to whatever dtype NumPy picks for an empty
  # array.
  promotable = promotable or [np.asarray([])]
  return np_dtypes._result_type(*promotable)  # pylint: disable=protected-access
def testAllowF64TruePreferF32False(self):
  """float64 allowed, float32 not preferred: Python scalars use 64-bit types.

  Python floats promote to float64 and Python complex numbers to complex128.
  """
  np_dtypes.set_allow_float64(True)
  np_dtypes.set_prefer_float32(False)
  wide_float = dtypes.float64
  self.assertEqual(wide_float, np_dtypes.default_float_type())
  self.assertEqual(wide_float, np_dtypes._result_type(1.1))
  wide_complex = dtypes.complex128
  self.assertEqual(wide_complex, np_dtypes._result_type(1.j))
def testAllowF64False(self, prefer_f32):
  """float64 disallowed: float32 results regardless of `prefer_f32`.

  Even an explicit float64 operand promotes to float32 in this mode.
  """
  np_dtypes.set_allow_float64(False)
  np_dtypes.set_prefer_float32(prefer_f32)
  narrow = dtypes.float32
  self.assertEqual(narrow, np_dtypes.default_float_type())
  f64_scalar = np.zeros([], np.float64)
  self.assertEqual(narrow, np_dtypes._result_type(f64_scalar, 1.1))