def test_enable_eager_for_tf_1(self, mock_eager_execution, mock_is_tf_v1):
    """Verifies `utils.enable_eager_for_tf_1` is gated on the TF-v1 check.

    The helper is invoked twice. In the first call, the `is_tf_v1` mock returns
    False and the eager-execution mock must not be called. In the second call,
    the `is_tf_v1` mock returns True and the eager-execution mock must be
    called exactly once, with no arguments.
    """
    # First pass: the version check reports False, so nothing should happen.
    mock_is_tf_v1.return_value = False
    utils.enable_eager_for_tf_1()
    # `assert_not_called` is not available in the py35 test infra, so check
    # the recorded call count on the mock directly instead.
    self.assertEqual(mock_eager_execution.call_count, 0)

    # Second pass: the version check reports True, so eager execution should
    # be requested exactly once with no arguments.
    mock_is_tf_v1.return_value = True
    utils.enable_eager_for_tf_1()
    mock_eager_execution.assert_called_once_with()
# limitations under the License. """Tests for cloud_fit.client.""" import os import tempfile from unittest import mock import cloudpickle from googleapiclient import discovery import tensorflow as tf import tensorflow_datasets as tfds from tensorflow_cloud.experimental.cloud_fit import client from tensorflow_cloud.experimental.cloud_fit import utils # Can only export Datasets which were created executing eagerly utils.enable_eager_for_tf_1() MIRRORED_STRATEGY_NAME = utils.MIRRORED_STRATEGY_NAME MULTI_WORKER_MIRRORED_STRATEGY_NAME = utils.MULTI_WORKER_MIRRORED_STRATEGY_NAME class CloudFitClientTest(tf.test.TestCase): def setUp(self): super(CloudFitClientTest, self).setUp() self._image_uri = "gcr.io/some_test_image:latest" self._project_id = "test_project_id" self._region = "test_region" self._mock_apiclient = mock.Mock() self._remote_dir = tempfile.mkdtemp() self._job_spec = client._default_job_spec( self._region,