def main(enable_v2_behavior=True):
  """All-in-one main function for tf.distribute tests.

  Args:
    enable_v2_behavior: when truthy, `v2_compat.enable_v2_behavior()` is
      called before the tests run; otherwise
      `v2_compat.disable_v2_behavior()` is called.
  """
  # Select the toggle first, then invoke it: exactly one of the two
  # v2_compat functions runs, matching the if/else in the original.
  toggle_v2 = (v2_compat.enable_v2_behavior if enable_v2_behavior
               else v2_compat.disable_v2_behavior)
  toggle_v2()
  # TODO(b/131360402): configure default logical devices.
  multi_process_runner.test_main()
def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests.

  Args:
    enable_v2_behavior: when truthy, `v2_compat.enable_v2_behavior()` is
      called; otherwise `v2_compat.disable_v2_behavior()` is called.
    config_logical_devices: when truthy, `_set_logical_devices` is registered
      via `app.call_after_init` (presumably so it runs once the absl app has
      initialized -- confirm against absl docs).
  """
  # Registration happens before the v2 toggle, as in the original ordering.
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)

  if not enable_v2_behavior:
    v2_compat.disable_v2_behavior()
  else:
    v2_compat.enable_v2_behavior()

  multi_process_runner.test_main()
def test_basic(self):
  """Checks the effect of `v2_compat.disable_v2_behavior()`.

  Before the call, a new constant is an `ops.EagerTensor` and
  `_pywrap_tf2.is_enabled()` is truthy; after the call, neither holds.
  """
  const_before = constant_op.constant([1, 2, 3])  # creates a hidden context
  # assertIsInstance/assertNotIsInstance report the actual type on failure,
  # unlike assertTrue(isinstance(...)).
  self.assertIsInstance(const_before, ops.EagerTensor)
  self.assertTrue(_pywrap_tf2.is_enabled())

  v2_compat.disable_v2_behavior()
  # NOTE(review): v2 behavior is intentionally not restored here. Re-enabling
  # eager execution after graph-mode tensors have been created may not be
  # supported -- confirm before adding an addCleanup that re-enables it.

  const_after = constant_op.constant([1, 2, 3])
  self.assertNotIsInstance(const_after, ops.EagerTensor)
  self.assertFalse(_pywrap_tf2.is_enabled())
def test_tf2_disable_tf2_behaviour(self):
  """Toggling v2 behavior flips both `tf2.enabled()` and `_pywrap_tf2`.

  The test starts with TF2 enabled, disables it, then re-enables it, checking
  the Python-level and `_pywrap_tf2` flags after each step.
  """

  def _check_tf2_flags(expect_enabled):
    # Same assertion method for both flags, Python-level flag checked first.
    check = self.assertTrue if expect_enabled else self.assertFalse
    check(tf2.enabled())
    check(_pywrap_tf2.is_enabled())

  _check_tf2_flags(True)
  v2_compat.disable_v2_behavior()
  _check_tf2_flags(False)
  v2_compat.enable_v2_behavior()
  _check_tf2_flags(True)
from absl import flags from absl import logging # pylint: disable=g-direct-tensorflow-import from tensorflow.python.compat import v2_compat from tensorflow.python.framework import function # The following imports are needed to expose private _Send/_Recv ops # on TensorFlow 1.X. They could be removed once support for 1.X is dropped. from google.protobuf import text_format as _text_format from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 from tensorflow.python.framework import op_def_library as _op_def_library from tensorflow.python.framework import op_def_registry as _op_def_registry # pylint: enable=g-direct-tensorflow-import v2_compat.disable_v2_behavior() Defun = function.Defun # TODO(slebedev): Remove after there is no need to support 1.X. def _InitOpDefLibrary(): op_list = _op_def_pb2.OpList() _text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list) _op_def_registry.register_op_list(op_list) op_def_lib = _op_def_library.OpDefLibrary() op_def_lib.add_op_list(op_list) return op_def_lib _InitOpDefLibrary.op_list_ascii = """\ op {
def test_basic(self):
  """Checks that `v2_compat.disable_v2_behavior()` stops eager tensors.

  A constant created before the call is an `ops.EagerTensor`; one created
  after the call is not.
  """
  const_before = constant_op.constant([1, 2, 3])  # creates a hidden context
  # assertIsInstance/assertNotIsInstance report the actual type on failure,
  # unlike assertTrue(isinstance(...)).
  self.assertIsInstance(const_before, ops.EagerTensor)

  v2_compat.disable_v2_behavior()

  const_after = constant_op.constant([1, 2, 3])
  self.assertNotIsInstance(const_after, ops.EagerTensor)