예제 #1
0
 def testAllowF64TruePreferF32True(self):
     """float64 allowed + float32 preferred: Python floats promote to float32.

     An explicit float64 array operand still dominates a Python float literal.
     """
     np_dtypes.set_allow_float64(True)
     np_dtypes.set_prefer_float32(True)
     preferred = dtypes.float32
     self.assertEqual(preferred, np_dtypes.default_float_type())
     self.assertEqual(preferred, np_dtypes._result_type(1.1))
     f64_scalar_array = np.zeros([], np.float64)
     promoted = np_dtypes._result_type(f64_scalar_array, 1.1)
     self.assertEqual(dtypes.float64, promoted)
예제 #2
0
def result_type(*arrays_and_dtypes):
  """Returns the type resulting from applying NumPy type promotion to arguments.

  TF-side values are first reduced to NumPy dtypes: `np_arrays.ndarray`,
  `ops.Tensor` and `IndexedSlices` contribute the NumPy form of their
  `.dtype`, and a `dtypes.DType` is converted directly. Every other element
  (e.g. Python scalars, `np.ndarray`, NumPy dtypes) is forwarded unchanged.

  Args:
    *arrays_and_dtypes: A list of array_like objects or dtypes. Nested
      structures are flattened before normalization.

  Returns:
    A numpy dtype. With no inputs at all, the result is NumPy's choice for an
    empty `np.asarray([])` array.
  """

  def to_promotion_operand(item):
    # np.ndarray is intentionally not in this tuple: np.result_type looks at
    # the value (not just the dtype) of an np.ndarray to decide the result
    # type, so those arrays have to reach it unchanged.
    tensor_like = (np_arrays.ndarray, ops.Tensor, indexed_slices.IndexedSlices)
    if isinstance(item, tensor_like):
      return _to_numpy_type(item.dtype)
    if isinstance(item, dtypes.DType):
      return _to_numpy_type(item)
    return item

  operands = list(map(to_promotion_operand, nest.flatten(arrays_and_dtypes)))
  if not operands:
    # Nothing to promote: let NumPy decide based on an empty array.
    operands = [np.asarray([])]
  return np_dtypes._result_type(*operands)  # pylint: disable=protected-access
예제 #3
0
def _result_type_binary(t1, t2):
  """Two-argument specialization of `result_type`, kept for performance.

  Both operands are normalized with `_maybe_get_dtype` and promoted directly
  via `np_dtypes._result_type`. If either step raises ValueError, the
  general `result_type(t1, t2)` is used instead.

  Args:
    t1: First operand (array_like object or dtype).
    t2: Second operand (array_like object or dtype).

  Returns:
    The promoted dtype from the fast path, or from `result_type` on fallback.
  """
  try:
    lhs = _maybe_get_dtype(t1)
    rhs = _maybe_get_dtype(t2)
    promoted = np_dtypes._result_type(lhs, rhs)  # pylint: disable=protected-access
  except ValueError:
    # The fast path rejected these operands; defer to the general version.
    return result_type(t1, t2)
  return promoted
예제 #4
0
def result_type(*arrays_and_dtypes):
    """Returns the dtype obtained by NumPy type promotion of the arguments.

    Args:
      *arrays_and_dtypes: Array_like objects and/or dtypes. Nested structures
        are flattened, then each element is normalized with
        `_maybe_get_dtype`.

    Returns:
      The result of `np_dtypes._result_type` on the normalized elements. With
      no elements, NumPy decides based on an empty `np.asarray([])`.
    """
    normalized = list(map(_maybe_get_dtype, nest.flatten(arrays_and_dtypes)))
    # An empty list is falsy, so NumPy picks the dtype of an empty array then.
    operands = normalized or [np.asarray([])]
    return np_dtypes._result_type(*operands)  # pylint: disable=protected-access
예제 #5
0
def result_type(*arrays_and_dtypes):
    """Returns the dtype produced by NumPy type promotion of the arguments.

    Args:
      *arrays_and_dtypes: Array_like objects and/or dtypes, possibly nested;
        the structure is flattened first.

    Returns:
      The dtype computed by `np_dtypes._result_type`. With no inputs, NumPy
      decides based on an empty `np.asarray([])`.
    """

    def normalize(value):
        # np.ndarray is deliberately not matched here: np.result_type looks
        # at an np.ndarray's value (not just its dtype) to pick the result.
        tensor_like = (np_arrays.ndarray, core.Tensor,
                       indexed_slices.IndexedSlices)
        if isinstance(value, tensor_like):
            return _to_numpy_type(value.dtype)
        if isinstance(value, dtypes.DType):
            return _to_numpy_type(value)
        return value

    operands = [normalize(value) for value in nest.flatten(arrays_and_dtypes)]
    if not operands:
        # Nothing to promote: let NumPy choose for an empty array.
        return np_dtypes._result_type(np.asarray([]))  # pylint: disable=protected-access
    return np_dtypes._result_type(*operands)  # pylint: disable=protected-access
예제 #6
0
 def testAllowF64TruePreferF32False(self):
     """float64 allowed, float32 not preferred: Python scalars map to 64-bit."""
     np_dtypes.set_allow_float64(True)
     np_dtypes.set_prefer_float32(False)
     default_float = np_dtypes.default_float_type()
     self.assertEqual(dtypes.float64, default_float)
     float_literal_type = np_dtypes._result_type(1.1)
     self.assertEqual(dtypes.float64, float_literal_type)
     complex_literal_type = np_dtypes._result_type(1.j)
     self.assertEqual(dtypes.complex128, complex_literal_type)
예제 #7
0
 def testAllowF64False(self, prefer_f32):
     """float64 disallowed: float32 is used whatever `prefer_f32` is.

     This holds even when one operand is an explicit float64 array.
     """
     np_dtypes.set_allow_float64(False)
     np_dtypes.set_prefer_float32(prefer_f32)
     self.assertEqual(dtypes.float32, np_dtypes.default_float_type())
     mixed = np_dtypes._result_type(np.zeros([], np.float64), 1.1)
     self.assertEqual(dtypes.float32, mixed)