Esempio n. 1
0
 def testAllowF64TruePreferF32True(self):
     """float64 allowed, float32 preferred.

     Python float literals (and the default float type) should resolve to
     float32, while an explicit float64 operand still yields float64.
     """
     # NOTE(review): these setters presumably change process-wide state;
     # restoring it is assumed to happen elsewhere (e.g. tearDown) — confirm.
     np_dtypes.set_allow_float64(True)
     np_dtypes.set_prefer_float32(True)

     default_type = np_dtypes.default_float_type()
     self.assertEqual(dtypes.float32, default_type)

     literal_type = np_dtypes._result_type(1.1)
     self.assertEqual(dtypes.float32, literal_type)

     f64_scalar = np.zeros([], np.float64)
     mixed_type = np_dtypes._result_type(f64_scalar, 1.1)
     self.assertEqual(dtypes.float64, mixed_type)
Esempio n. 2
0
def enable_numpy_behavior(prefer_float32=False):
    """Enable NumPy behavior on Tensors.

    Includes addition of methods, type promotion on operator overloads and
    support for NumPy-style slicing.

    Args:
      prefer_float32: Whether type inference should prefer float32 for Python
        floats, or use float64 similar to NumPy. This only sets the
        preference; it does not forbid float64 operands that are passed
        explicitly.
    """
    ops.enable_numpy_style_type_promotion()
    ops.enable_numpy_style_slicing()
    np_math_ops.enable_numpy_methods_on_tensor()
    # Use the dedicated "prefer" knob instead of toggling allow_float64:
    # set_allow_float64(False) would also downcast explicit float64 operands
    # to float32, which is stronger than "prefer". Setting it unconditionally
    # also means prefer_float32=False really reverts to the float64 default
    # if an earlier call had preferred float32.
    np_dtypes.set_prefer_float32(prefer_float32)
Esempio n. 3
0
 def _test(self, *args, **kw_args):
     """Checks that `self.np_func` and `self.onp_func` agree on shape/dtype.

     Both functions are called with the same arguments and the shape and
     dtype of their outputs are compared.

     Args:
       *args: Positional arguments forwarded to both functions.
       **kw_args: Keyword arguments forwarded to both functions, except:
         onp_dtype: If given, the onp output is cast to this dtype before
           comparison.
         allow_float64: Value passed to `np_dtypes.set_allow_float64` for the
           duration of the check (default True). The previous value is
           restored afterwards, even if an assertion fails.
     """
     onp_dtype = kw_args.pop('onp_dtype', None)
     allow_float64 = kw_args.pop('allow_float64', True)
     old_allow_float64 = np_dtypes.is_allow_float64()
     np_dtypes.set_allow_float64(allow_float64)
     # try/finally so a failing assertion cannot leak allow_float64 into
     # whatever runs next.
     try:
         np_out = self.np_func(*args, **kw_args)
         # TODO(agarwal): Note that onp can return a scalar type while np
         # returns ndarrays. Currently np does not support scalar types.
         # Normalize with onp.asarray at the call site rather than rebinding
         # self.onp_func to a wrapper, which would nest one more wrapper on
         # every call to _test.
         onp_out = onp.asarray(self.onp_func(*args, **kw_args))
         if onp_dtype is not None:
             onp_out = onp_out.astype(onp_dtype)
         self.assertEqual(np_out.shape, onp_out.shape)
         self.assertEqual(np_out.dtype, onp_out.dtype)
     finally:
         np_dtypes.set_allow_float64(old_allow_float64)
 def testAllowF64TruePreferF32False(self):
     """float64 allowed, float32 not preferred.

     The default float type and Python float literals resolve to float64,
     and Python complex literals resolve to complex128.
     """
     np_dtypes.set_allow_float64(True)
     np_dtypes.set_prefer_float32(False)

     default_type = np_dtypes.default_float_type()
     self.assertEqual(dtypes.float64, default_type)

     float_literal_type = np_dtypes._result_type(1.1)
     self.assertEqual(dtypes.float64, float_literal_type)

     complex_literal_type = np_dtypes._result_type(1.j)
     self.assertEqual(dtypes.complex128, complex_literal_type)
 def testAllowF64False(self, prefer_f32):
     """float64 disallowed: results resolve to float32.

     The expectations below do not depend on `prefer_f32`; even an explicit
     float64 operand combined with a Python float resolves to float32.
     """
     # NOTE(review): `prefer_f32` has no default, so this method presumably
     # relies on a parameterizing decorator outside this view — confirm.
     np_dtypes.set_allow_float64(False)
     np_dtypes.set_prefer_float32(prefer_f32)

     default_type = np_dtypes.default_float_type()
     self.assertEqual(dtypes.float32, default_type)

     f64_scalar = np.zeros([], np.float64)
     mixed_type = np_dtypes._result_type(f64_scalar, 1.1)
     self.assertEqual(dtypes.float32, mixed_type)