Ejemplo n.º 1
0
def main(enable_v2_behavior=True):
    """All-in-one main function for tf.distribute tests.

    Selects TF2 or TF1 behavior and then hands control to
    `multi_process_runner.test_main()`.

    Args:
      enable_v2_behavior: when truthy, `v2_compat.enable_v2_behavior()` is
        called; otherwise `v2_compat.disable_v2_behavior()` is called.
    """
    # Pick the behavior toggle once, then invoke it.
    toggle_behavior = (
        v2_compat.enable_v2_behavior
        if enable_v2_behavior
        else v2_compat.disable_v2_behavior
    )
    toggle_behavior()
    # TODO(b/131360402): configure default logical devices.
    multi_process_runner.test_main()
Ejemplo n.º 2
0
def main(enable_v2_behavior=True, config_logical_devices=True):
    """All-in-one main function for tf.distribute tests.

    Args:
      enable_v2_behavior: when truthy, `v2_compat.enable_v2_behavior()` is
        called; otherwise `v2_compat.disable_v2_behavior()` is called.
      config_logical_devices: when truthy, `_set_logical_devices` is
        registered via `app.call_after_init` before anything else runs.
    """
    if config_logical_devices:
        # NOTE(review): call_after_init presumably defers the callback until
        # absl initialization has finished — confirm against absl docs.
        app.call_after_init(_set_logical_devices)

    if not enable_v2_behavior:
        v2_compat.disable_v2_behavior()
    else:
        v2_compat.enable_v2_behavior()

    multi_process_runner.test_main()
Ejemplo n.º 3
0
 def test_basic(self):
     """Asserts that disable_v2_behavior() switches off eager tensors and TF2.

     Before the call, a constant is expected to be an EagerTensor and
     `_pywrap_tf2.is_enabled()` to be true; after the call, neither holds.
     """
     tensor_before = constant_op.constant([1, 2, 3])  # creates a hidden context
     self.assertIsInstance(tensor_before, ops.EagerTensor)
     self.assertTrue(_pywrap_tf2.is_enabled())

     # NOTE(review): this flips process-global TF behavior and never restores
     # it; confirm no later test in this class relies on v2 being enabled.
     v2_compat.disable_v2_behavior()

     tensor_after = constant_op.constant([1, 2, 3])
     self.assertNotIsInstance(tensor_after, ops.EagerTensor)
     self.assertFalse(_pywrap_tf2.is_enabled())
Ejemplo n.º 4
0
    def test_tf2_disable_tf2_behaviour(self):
        """Toggles v2 behavior off and back on, checking both TF2 flags.

        Both `tf2.enabled()` and `_pywrap_tf2.is_enabled()` must agree with
        the expected state at every step: on, then off, then on again.
        """

        def expect_tf2(expected_on):
            # Same assertion method for both flags, chosen by expected state.
            check = self.assertTrue if expected_on else self.assertFalse
            check(tf2.enabled())
            check(_pywrap_tf2.is_enabled())

        expect_tf2(True)

        v2_compat.disable_v2_behavior()
        expect_tf2(False)

        v2_compat.enable_v2_behavior()
        expect_tf2(True)
Ejemplo n.º 5
0
from absl import flags
from absl import logging
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.compat import v2_compat

from tensorflow.python.framework import function

# The following imports are needed to expose private _Send/_Recv ops
# on TensorFlow 1.X. The could be removed once support for 1.X is dropped.
from google.protobuf import text_format as _text_format
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.framework import op_def_registry as _op_def_registry
# pylint: enable=g-direct-tensorflow-import

# Module-level alias for `function.Defun`.
Defun = function.Defun
# Executed at import time: switch this process to TF1 (non-v2) behavior.
v2_compat.disable_v2_behavior()


# TODO(slebedev): Remove after there is no need to support 1.X.
def _InitOpDefLibrary():
    """Builds an OpDefLibrary from the text-format ops in `op_list_ascii`.

    The text stored on the function attribute `_InitOpDefLibrary.op_list_ascii`
    (assigned after this definition, so it is read at call time) is merged
    into an `OpList` proto. The ops are registered with `_op_def_registry`,
    and a new `OpDefLibrary` containing them is returned.
    """
    ascii_ops = _InitOpDefLibrary.op_list_ascii
    parsed_ops = _op_def_pb2.OpList()
    _text_format.Merge(ascii_ops, parsed_ops)
    _op_def_registry.register_op_list(parsed_ops)

    library = _op_def_library.OpDefLibrary()
    library.add_op_list(parsed_ops)
    return library


_InitOpDefLibrary.op_list_ascii = """\
op {
Ejemplo n.º 6
0
 def test_basic(self):
   """Asserts that disable_v2_behavior() stops constants being EagerTensors."""
   tensor_before = constant_op.constant([1, 2, 3])  # creates a hidden context
   self.assertIsInstance(tensor_before, ops.EagerTensor)
   # NOTE(review): process-global state is changed and not restored here.
   v2_compat.disable_v2_behavior()
   tensor_after = constant_op.constant([1, 2, 3])
   self.assertNotIsInstance(tensor_after, ops.EagerTensor)