Example #1
0
 def setUp(self):
     """Prepare serialized training assets and a TF_CONFIG for each test.

     Creates a fresh temporary remote directory, builds the model, serializes
     it together with the fit() keyword arguments via
     ``client._serialize_assets`` and sets ``TF_CONFIG`` to a two-worker
     localhost cluster in which this process is worker 0. The temporary
     directory is deleted and the previous ``TF_CONFIG`` value restored once
     the test finishes.
     """
     import shutil  # Local import: only needed for the cleanup registered below.

     super(CloudFitRemoteTest, self).setUp()
     self._image_uri = "gcr.io/some_test_image:latest"
     self._project_id = "test_project_id"
     self._remote_dir = tempfile.mkdtemp()
     # mkdtemp() never deletes the directory itself; remove it after the test.
     # ignore_errors keeps this harmless if something else already removed it.
     self.addCleanup(shutil.rmtree, self._remote_dir, ignore_errors=True)
     self._output_dir = os.path.join(self._remote_dir, "checkpoint")
     self._x = np.random.random(10)
     self._y = np.random.random(10)
     # NOTE(review): this replaces the (presumably class-level) `_model`
     # factory with the model it builds, on this instance only; tests then
     # read `self._model` as the built model.
     self._model = self._model()
     self._logs_dir = os.path.join(self._remote_dir, "logs")
     self._fit_kwargs = {
         "x": self._x,
         "y": self._y,
         "verbose": 2,
         "batch_size": 2,
         "epochs": 10,
         "callbacks":
         [tf.keras.callbacks.TensorBoard(log_dir=self._logs_dir)],
     }
     client._serialize_assets(self._remote_dir, self._model,
                              **self._fit_kwargs)

     # TF_CONFIG is process-global; restore whatever was there before so this
     # cluster definition does not leak into other tests.
     previous_tf_config = os.environ.get("TF_CONFIG")

     def _restore_tf_config():
         if previous_tf_config is None:
             os.environ.pop("TF_CONFIG", None)
         else:
             os.environ["TF_CONFIG"] = previous_tf_config

     self.addCleanup(_restore_tf_config)
     os.environ["TF_CONFIG"] = json.dumps({
         "cluster": {
             "worker": ["localhost:9999", "localhost:9999"]
         },
         "task": {
             "type": "worker",
             "index": 0
         },
     })
Example #2
0
    def test_custom_callback(self):
        """Check that a user-defined callback is run by ``remote.run``.

        The mock that ``CustomCallbackExample`` presumably invokes must not be
        called while assets are serialized, and must be called exactly once,
        with no arguments, once ``remote.run`` executes the training job.
        """
        if utils.is_tf_v1():
            # TF 1.x is not supported. Report the test as skipped instead of
            # silently returning, which would count it as passed.
            self.skipTest("TF 1.x is not supported.")

        # Reset the shared mock before use (presumably clears any calls
        # recorded by earlier tests).
        _MockCallable.reset()

        self._fit_kwargs["callbacks"] = [CustomCallbackExample()]
        client._serialize_assets(self._remote_dir, self._model,
                                 **self._fit_kwargs)

        # Serializing the assets must not trigger the callback.
        _MockCallable.mock_callable.assert_not_called()

        remote.run(self._remote_dir, MIRRORED_STRATEGY_NAME)
        # Running the job must trigger the callback exactly once.
        _MockCallable.mock_callable.assert_called_once_with()
    def test_serialize_assets(self):
        """Check that ``client._serialize_assets`` writes model and callbacks.

        On TF 1.x this instead checks that ``client.cloud_fit`` raises
        ``RuntimeError``. Otherwise it serializes the model with a TensorBoard
        callback. It then verifies that the ``training_assets`` and ``model``
        directories are non-empty and that the pickled callbacks round-trip
        back to a ``TensorBoard`` instance.
        """
        if utils.is_tf_v1():
            # TF 1.x is not supported: cloud_fit must refuse to run.
            with self.assertRaises(RuntimeError):
                client.cloud_fit(
                    self._model,
                    x=self._dataset,
                    validation_data=self._dataset,
                    remote_dir=self._remote_dir,
                    job_spec=self._job_spec,
                    batch_size=1,
                    epochs=2,
                    verbose=3,
                )
            return

        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=self._remote_dir)
        # Build a copy rather than mutating self._scalar_fit_kwargs in place,
        # so the added callback cannot leak into other users of that dict.
        args = dict(self._scalar_fit_kwargs, callbacks=[tensorboard_callback])

        client._serialize_assets(self._remote_dir, self._model, **args)

        training_assets_dir = os.path.join(self._remote_dir, "training_assets")
        model_dir = os.path.join(self._remote_dir, "model")
        self.assertGreaterEqual(
            len(tf.io.gfile.listdir(training_assets_dir)), 1)
        self.assertGreaterEqual(len(tf.io.gfile.listdir(model_dir)), 1)

        # The training assets load as a SavedModel whose callbacks_fn()
        # returns the pickled callback list, which is unpickled below.
        training_assets_graph = tf.saved_model.load(training_assets_dir)

        pickled_callbacks = tfds.as_numpy(training_assets_graph.callbacks_fn())
        unpickled_callbacks = pickle.loads(pickled_callbacks)
        self.assertIsInstance(unpickled_callbacks[0],
                              tf.keras.callbacks.TensorBoard)