示例#1
0
def enable_numpy_behavior(prefer_float32=False):
    """Enable NumPy behavior on Tensors.

    Includes addition of methods, type promotion on operator overloads and
    support for NumPy-style slicing.

    Args:
        prefer_float32: Whether to allow type inference to use float32, or
            use float64 similar to NumPy. The value is forwarded unchanged
            to `np_dtypes.set_prefer_float32`.
    """
    # Operator overloads: NumPy-style dtype promotion, then NumPy-style slicing.
    ops.enable_numpy_style_type_promotion()
    ops.enable_numpy_style_slicing()
    # Add the NumPy-style methods to tensors.
    np_math_ops.enable_numpy_methods_on_tensor()
    # Hand the caller's float32/float64 preference to the dtype helpers.
    np_dtypes.set_prefer_float32(prefer_float32)
示例#2
0
def enable_numpy_behavior(prefer_float32=False):
    """Enable NumPy behavior on Tensors.

    Enabling NumPy behavior has three effects:

    * It adds to `tf.Tensor` some common NumPy methods such as `T`,
      `reshape` and `ravel`.
    * It changes dtype promotion in `tf.Tensor` operators to be
      compatible with NumPy. For example,
      `tf.ones([], tf.int32) + tf.ones([], tf.float32)` used to throw a
      "dtype incompatible" error, but after this it will return a
      float64 tensor (obeying NumPy's promotion rules).
    * It enhances `tf.Tensor`'s indexing capability to be on par with
      [NumPy's](https://numpy.org/doc/stable/reference/arrays.indexing.html).

    Args:
        prefer_float32: Controls whether dtype inference will use float32
            for Python floats, or float64 (the default and the
            NumPy-compatible behavior). The value is forwarded unchanged
            to `np_dtypes.set_prefer_float32`.
    """
    # Operator overloads: NumPy-style dtype promotion, then NumPy-style slicing.
    ops.enable_numpy_style_type_promotion()
    ops.enable_numpy_style_slicing()
    # Add the NumPy-style methods (`T`, `reshape`, `ravel`, ...) to tensors.
    np_math_ops.enable_numpy_methods_on_tensor()
    # Hand the caller's float32/float64 preference to the dtype helpers.
    np_dtypes.set_prefer_float32(prefer_float32)
示例#3
0
            return y, z

        with self.assertRaises(TypeError):
            f(np.asarray([3, 4]))

    def testIndex(self):
        # Using the tf.function's argument as the index into a plain Python
        # list is expected to be rejected with a TypeError.
        @tf.function
        def pick(position):
            return [0, 1][position]

        index_array = np.asarray([1])
        with self.assertRaises(TypeError):
            pick(index_array)


class VariableTest(InteropTest):
    """Exercises NumPy functions applied directly to a `tf.Variable`."""

    def test(self):
        variable = tf.Variable(2.0)

        # np.square on the variable should hand back a real NumPy array.
        squared = np.square(variable)
        self.assertIsInstance(squared, np.ndarray)
        self.assertAllClose(4.0, squared)

        # The read is issued under a control dependency on the assign_add,
        # so it should observe 2.0 + 4.0.
        increment = variable.assign_add(squared)
        with tf.control_dependencies([increment]):
            updated = variable.read_value()
        self.assertAllClose(6.0, updated)


if __name__ == '__main__':
    # Switch on the NumPy-style behaviors before any test runs.
    # NOTE(review): ops.enable_numpy_style_slicing() is not called here —
    # confirm whether that is intentional for this test file.
    ops.enable_numpy_style_type_promotion()
    np_math_ops.enable_numpy_methods_on_tensor()
    # Enable eager execution, then hand control to the TensorFlow test runner.
    tf.compat.v1.enable_eager_execution()
    tf.test.main()