def enable_numpy_behavior(prefer_float32=False):
  """Switches `tf.Tensor` over to NumPy-like behavior.

  Three switches are turned on: NumPy-style type promotion for Tensor
  operator overloads, NumPy-style slicing, and the NumPy methods attached
  to `tf.Tensor`. After that, the float-inference preference is recorded.

  Args:
    prefer_float32: When True, type inference may use float32. When False
      (the default), it uses float64, matching NumPy.
  """
  # Run the enablers in this fixed sequence. Nothing visible here shows
  # that they are order-independent, so keep the order.
  for enable in (
      ops.enable_numpy_style_type_promotion,
      ops.enable_numpy_style_slicing,
      np_math_ops.enable_numpy_methods_on_tensor,
  ):
    enable()
  np_dtypes.set_prefer_float32(prefer_float32)
def enable_numpy_behavior(prefer_float32=False):
  """Turns on NumPy-compatible behavior for `tf.Tensor`.

  Calling this has three effects:

  * Common NumPy methods such as `T`, `reshape` and `ravel` become
    available on `tf.Tensor`.
  * dtype promotion in `tf.Tensor` operators follows NumPy's rules. For
    example, `tf.ones([], tf.int32) + tf.ones([], tf.float32)` would
    otherwise raise a "dtype incompatible" error. With this enabled, it
    returns a float64 tensor, as NumPy's promotion rules dictate.
  * `tf.Tensor` indexing becomes as capable as
    [NumPy's](https://numpy.org/doc/stable/reference/arrays.indexing.html).

  Args:
    prefer_float32: Whether dtype inference uses float32 for Python floats.
      The default, False, uses float64, which is the NumPy-compatible
      behavior.
  """
  switches = (
      ops.enable_numpy_style_type_promotion,  # operator dtype promotion
      ops.enable_numpy_style_slicing,  # NumPy-style indexing
      np_math_ops.enable_numpy_methods_on_tensor,  # T, reshape, ravel, ...
  )
  # Applied in this exact order. Keep it unless the switches are known
  # to be order-independent.
  for switch_on in switches:
    switch_on()
  np_dtypes.set_prefer_float32(prefer_float32)
return y, z with self.assertRaises(TypeError): f(np.asarray([3, 4]))

  def testIndex(self):
    """A traced function that indexes a Python list with its argument raises TypeError."""

    @tf.function
    def f(x):
      # Indexes a plain Python list with the function argument. The
      # assertion below requires this call to end in TypeError (presumably
      # because the argument is not a Python integer index).
      return [0, 1][x]

    with self.assertRaises(TypeError):
      f(np.asarray([1]))


class VariableTest(InteropTest):
  """Interop check between `tf.Variable` and NumPy functions."""

  def test(self):
    """np.square on a tf.Variable yields an ndarray usable in assign_add."""
    tf_var = tf.Variable(2.0)
    # NumPy function applied directly to the Variable. The result is
    # asserted to be a plain np.ndarray holding 2.0 ** 2 == 4.0.
    value = np.square(tf_var)
    self.assertIsInstance(value, np.ndarray)
    self.assertAllClose(4.0, value)
    # read_value() is created under a control dependency on assign_add,
    # so it observes 2.0 + 4.0 == 6.0.
    with tf.control_dependencies([tf_var.assign_add(value)]):
      tf_var_value = tf_var.read_value()
    self.assertAllClose(6.0, tf_var_value)


if __name__ == '__main__':
  # Turn on NumPy-style dtype promotion and the NumPy methods on Tensor
  # before running the tests.
  # NOTE(review): NumPy-style slicing (ops.enable_numpy_style_slicing) is not
  # enabled here. Confirm this is intentional for these tests.
  ops.enable_numpy_style_type_promotion()
  np_math_ops.enable_numpy_methods_on_tensor()
  # NOTE(review): under TF2, eager execution is presumably already the
  # default, so this call may be redundant. Confirm.
  tf.compat.v1.enable_eager_execution()
  tf.test.main()