Example #1
0
    def test_enable_eager_for_tf_1(self, mock_eager_execution, mock_is_tf_v1):
        """Checks that enable_eager_for_tf_1 only acts when TF v1 is reported.

        Args:
          mock_eager_execution: Mock standing in for the call that
            `utils.enable_eager_for_tf_1` may make to enable eager mode
            (presumably injected by a `mock.patch` decorator outside this
            view — TODO confirm).
          mock_is_tf_v1: Mock whose return value presumably decides whether
            the helper treats the runtime as TF v1.
        """
        # Case 1: is_tf_v1 reports False -> the eager mock must stay untouched.
        mock_is_tf_v1.return_value = False
        utils.enable_eager_for_tf_1()
        # `assert_not_called` is unavailable in the py35 test infra, so read
        # the `called` flag and assert on it directly.
        eager_enabled = mock_eager_execution.called
        self.assertFalse(eager_enabled)

        # Case 2: is_tf_v1 reports True -> exactly one call, with no arguments.
        mock_is_tf_v1.return_value = True
        utils.enable_eager_for_tf_1()
        mock_eager_execution.assert_called_once_with()
Example #2
0
# limitations under the License.
"""Tests for cloud_fit.client."""

import os
import tempfile
from unittest import mock
import cloudpickle
from googleapiclient import discovery
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow_cloud.experimental.cloud_fit import client
from tensorflow_cloud.experimental.cloud_fit import utils

# Datasets can only be exported if they were created while executing eagerly.
# This call runs at import time, so it happens before any test in this module.
utils.enable_eager_for_tf_1()

# Module-level aliases for the distribution-strategy names defined in utils.
MIRRORED_STRATEGY_NAME, MULTI_WORKER_MIRRORED_STRATEGY_NAME = (
    utils.MIRRORED_STRATEGY_NAME,
    utils.MULTI_WORKER_MIRRORED_STRATEGY_NAME,
)


class CloudFitClientTest(tf.test.TestCase):
    def setUp(self):
        super(CloudFitClientTest, self).setUp()
        self._image_uri = "gcr.io/some_test_image:latest"
        self._project_id = "test_project_id"
        self._region = "test_region"
        self._mock_apiclient = mock.Mock()
        self._remote_dir = tempfile.mkdtemp()
        self._job_spec = client._default_job_spec(
            self._region,