    def test_job_id(self, mock_serialize_assets, mock_submit_job):
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(
                    self._model,
                    x=self._dataset,
                    validation_data=self._dataset,
                    remote_dir=self._remote_dir,
                    job_spec=self._job_spec,
                    batch_size=1,
                    epochs=2,
                    verbose=3,
                )
            return

        test_job_id = "test_job_id"
        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            job_id=test_job_id,
            batch_size=1,
            epochs=2,
            verbose=3,
        )

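        # mock_submit_job.call_args is a (positional_args, keyword_args)
        # pair; the job request body is the first positional argument, which
        # these tests unpack as `body`.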
        kargs, _ = mock_submit_job.call_args
        body, _ = kargs
        self.assertDictContainsSubset({
            "job_id": test_job_id,
        }, body)

    def test_distribution_strategy(self, mock_serialize_assets,
                                   mock_submit_job):
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(self._model,
                                 x=self._dataset,
                                 remote_dir=self._remote_dir)
            return

        client.cloud_fit(self._model,
                         x=self._dataset,
                         remote_dir=self._remote_dir)

        kargs, _ = mock_submit_job.call_args
        body, _ = kargs
        self.assertDictContainsSubset(
            {
                "args": [
                    "--remote_dir",
                    self._remote_dir,
                    "--distribution_strategy",
                    MULTI_WORKER_MIRRORED_STRATEGY_NAME,
                ],
            },
            body["trainingInput"],
        )

        client.cloud_fit(
            self._model,
            x=self._dataset,
            remote_dir=self._remote_dir,
            distribution_strategy=MIRRORED_STRATEGY_NAME,
            job_spec=self._job_spec,
        )

        kargs, _ = mock_submit_job.call_args
        body, _ = kargs
        self.assertDictContainsSubset(
            {
                "args": [
                    "--remote_dir",
                    self._remote_dir,
                    "--distribution_strategy",
                    MIRRORED_STRATEGY_NAME,
                ],
            },
            body["trainingInput"],
        )

        with self.assertRaises(ValueError):
            client.cloud_fit(
                self._model,
                x=self._dataset,
                remote_dir=self._remote_dir,
                distribution_strategy="not_implemented_strategy",
                job_spec=self._job_spec,
            )
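
        # A minimal sketch of the validation this test exercises (assumed
        # shape, not necessarily the library's actual implementation):
        #
        #   SUPPORTED_STRATEGIES = (MULTI_WORKER_MIRRORED_STRATEGY_NAME,
        #                           MIRRORED_STRATEGY_NAME)
        #   if distribution_strategy not in SUPPORTED_STRATEGIES:
        #       raise ValueError("Unsupported distribution strategy: "
        #                        "{}".format(distribution_strategy))
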
    def test_custom_job_spec(self, mock_submit_job):
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(
                    self._model,
                    x=self._dataset,
                    validation_data=self._dataset,
                    remote_dir=self._remote_dir,
                    job_spec=self._job_spec,
                    batch_size=1,
                    epochs=2,
                    verbose=3,
                )
            return

        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            batch_size=1,
            epochs=2,
            verbose=3,
        )

        kargs, _ = mock_submit_job.call_args
        body, _ = kargs
        self.assertDictContainsSubset(
            {
                "masterConfig": {
                    "imageUri": self._image_uri,
                },
                "args": [
                    "--remote_dir",
                    self._remote_dir,
                    "--distribution_strategy",
                    MULTI_WORKER_MIRRORED_STRATEGY_NAME,
                ],
            },
            body["trainingInput"],
        )
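
        # For reference: self._job_spec is assumed to be a job body whose
        # trainingInput carries the custom training image, along the lines of
        #   {"trainingInput": {"masterConfig": {"imageUri": self._image_uri}}}
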
    def test_client_with_tf_1x_raises_error(self):
        # This test is only applicable to TF 1.x
        if not utils.is_tf_v1():
            return

        x = np.random.random((2, 3))
        y = np.random.randint(0, 2, (2, 2))

        # TF 1.x is not supported, verify proper error is raised for TF 1.x.
        with self.assertRaises(RuntimeError):
            client.cloud_fit(
                self._model(),
                x=x,
                y=y,
                remote_dir="gs://some_test_dir",
                region=self._region,
                project_id=self._project_id,
                image_uri=self._image_uri,
                epochs=2,
            )

    def test_fit_kwargs(self, mock_submit_job):
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(
                    self._model,
                    x=self._dataset,
                    validation_data=self._dataset,
                    remote_dir=self._remote_dir,
                    job_spec=self._job_spec,
                    batch_size=1,
                    epochs=2,
                    verbose=3,
                )
            return
        job_id = client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            region=self._region,
            project_id=self._project_id,
            image_uri=self._image_uri,
            batch_size=1,
            epochs=2,
            verbose=3,
        )

        kargs, _ = mock_submit_job.call_args
        body, _ = kargs
        self.assertEqual(body["job_id"], job_id)
        # args == ["--remote_dir", <remote_dir>, ...], so index 1 holds the
        # remote directory the training assets were serialized to.
        remote_dir = body["trainingInput"]["args"][1]

        training_assets_graph = tf.saved_model.load(
            os.path.join(remote_dir, "training_assets"))
        elements = training_assets_graph.fit_kwargs_fn()
        self.assertDictContainsSubset(tfds.as_numpy(elements), {
            "batch_size": 1,
            "epochs": 2,
            "verbose": 3
        })
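
        # A minimal sketch of how a `fit_kwargs_fn` like the one loaded above
        # could be produced (hypothetical helper, not necessarily the
        # library's actual serializer):
        #
        #   module = tf.Module()
        #   module.fit_kwargs_fn = tf.function(
        #       lambda: {k: tf.constant(v) for k, v in fit_kwargs.items()},
        #       input_signature=[])
        #   tf.saved_model.save(
        #       module, os.path.join(remote_dir, "training_assets"))
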
    def test_serialize_assets(self):
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(
                    self._model,
                    x=self._dataset,
                    validation_data=self._dataset,
                    remote_dir=self._remote_dir,
                    job_spec=self._job_spec,
                    batch_size=1,
                    epochs=2,
                    verbose=3,
                )
            return
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=self._remote_dir)
        args = self._scalar_fit_kwargs
        args["callbacks"] = [tensorboard_callback]

        client._serialize_assets(self._remote_dir, self._model, **args)
        self.assertGreaterEqual(
            len(
                tf.io.gfile.listdir(
                    os.path.join(self._remote_dir, "training_assets"))), 1)
        self.assertGreaterEqual(
            len(tf.io.gfile.listdir(os.path.join(self._remote_dir, "model"))),
            1)

        training_assets_graph = tf.saved_model.load(
            os.path.join(self._remote_dir, "training_assets"))

        pickled_callbacks = tfds.as_numpy(training_assets_graph.callbacks_fn())
        unpickled_callbacks = pickle.loads(pickled_callbacks)
        self.assertIsInstance(unpickled_callbacks[0],
                              tf.keras.callbacks.TensorBoard)
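
        # The save side of this round trip is assumed to be plain pickle:
        # roughly pickle.dumps(callbacks) stored behind callbacks_fn at
        # serialization time, recovered here with pickle.loads.
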
    def test_in_memory_data(self):
        # This test should only run in tf 2.x
        if utils.is_tf_v1():
            return

        # Create a folder under remote dir for this test's data
        tmp_folder = str(uuid.uuid4())
        remote_dir = os.path.join(self._remote_dir, tmp_folder)

        # Keep track of test folders created for final clean up
        self._test_folders.append(remote_dir)

        x = np.random.random((2, 3))
        y = np.random.randint(0, 2, (2, 2))

        job_id = client.cloud_fit(
            self._model(),
            x=x,
            y=y,
            remote_dir=remote_dir,
            region=self._region,
            project_id=self._project_id,
            image_uri=self._image_uri,
            job_id="cloud_fit_e2e_test_{}_{}".format(
                _BUILD_ID.replace("-", "_"), "test_in_memory_data"),
            epochs=2,
        )
        logging.info("test_in_memory_data submitted with job id: %s", job_id)

        # Wait for AIP Training job to finish successfully
        self.assertTrue(
            google_api_client.wait_for_aip_training_job_completion(
                job_id, self._project_id))

        # load model from remote dir
        trained_model = tf.keras.models.load_model(
            os.path.join(remote_dir, "checkpoint"))
        eval_results = trained_model.evaluate(x, y)

        # Accuracy should be better than zero
        self.assertListEqual(trained_model.metrics_names, ["loss", "accuracy"])
        self.assertGreater(eval_results[1], 0)
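
        # (cloud_fit is assumed to write the trained model under
        # "<remote_dir>/checkpoint", which is why tf.keras.models.load_model
        # can reload it above.)
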
    def run_trial(self, trial, *fit_args, **fit_kwargs):
        """Evaluates a set of hyperparameter values.

        This method is called during `search` to evaluate a set of
        hyperparameters using AI Platform training.
        Arguments:
            trial: A `Trial` instance that contains the information
              needed to run this trial. `Hyperparameters` can be accessed
              via `trial.hyperparameters`.
            *fit_args: Positional arguments passed by `search`.
            **fit_kwargs: Keyword arguments passed by `search`.
        Raises:
            RuntimeError: If AIP training job fails.
        """

        # Running the training remotely.
        copied_fit_kwargs = copy.copy(fit_kwargs)

        # Handle any callbacks passed to `fit`.
        callbacks = fit_kwargs.pop("callbacks", [])
        callbacks = self._deepcopy_callbacks(callbacks)

        # Note: run_trial does not use `TunerCallback` calls, since
        # training is performed on AI Platform training remotely.

        # Handle TensorBoard/hyperparameter logging here. The TensorBoard
        # logs are used for passing metrics back from remote execution.
        self._add_logging(callbacks, trial)
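
        # (_add_logging is assumed to append a TensorBoard callback pointed
        # at a per-trial log directory, so the remote worker writes the event
        # files that this tuner later polls for metrics.)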

        # Create a model checkpoint callback with a saved-model file path
        # specific to this trial, to prevent different trials from
        # overwriting each other's models.
        self._add_model_checkpoint_callback(callbacks, trial.trial_id)

        copied_fit_kwargs["callbacks"] = callbacks
        model = self.hypermodel.build(trial.hyperparameters)

        remote_dir = os.path.join(self.directory, str(trial.trial_id))

        # TODO(b/170687807) Switch from using "{}".format() to f-string
        job_id = "{}_{}".format(self._study_id, trial.trial_id)

        # Create job spec from worker count and config
        job_spec = self._get_job_spec_from_config(job_id)

        tf.get_logger().info("Calling cloud_fit with %s", {
            "model": model,
            "remote_dir": remote_dir,
            "region": self._region,
            "project_id": self._project_id,
            "image_uri": self._container_uri,
            "job_id": job_id,
            "*fit_args": fit_args,
            "job_spec": job_spec,
            "**copied_fit_kwargs": copied_fit_kwargs})

        cloud_fit_client.cloud_fit(
            model=model,
            remote_dir=remote_dir,
            region=self._region,
            project_id=self._project_id,
            image_uri=self._container_uri,
            job_id=job_id,
            job_spec=job_spec,
            *fit_args,
            **copied_fit_kwargs)

        # Create an instance of the TensorBoard DirectoryWatcher to retrieve
        # the logs for this trial run.
        log_path = os.path.join(
            self._get_tensorboard_log_dir(trial.trial_id), "train")

        # The TensorBoard log watcher expects the path to exist.
        tf.io.gfile.makedirs(log_path)

        # TODO(b/170687807) Switch from using "{}".format() to f-string
        tf.get_logger().info(
            "Retrieving training logs for trial {} from {}".format(
                trial.trial_id, log_path))
        log_reader = tf_utils.get_tensorboard_log_watcher_from_path(log_path)
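
        # (get_tensorboard_log_watcher_from_path is assumed to wrap
        # TensorBoard's DirectoryWatcher; each polling pass yields any newly
        # written event protos from the trial's event files.)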

        # _TrainingMetrics is assumed to hold (completed_epoch_metrics,
        # partial_epoch_metrics): fully observed epochs ready to report to
        # the oracle, plus metrics for the still-in-progress epoch.
        training_metrics = _TrainingMetrics([], {})
        epoch = 0

        while google_api_client.is_aip_training_job_running(
            job_id, self._project_id):

            time.sleep(_POLLING_INTERVAL_IN_SECONDS)

            # Retrieve available metrics if any
            training_metrics = self._get_remote_training_metrics(
                log_reader, training_metrics.partial_epoch_metrics)

            for epoch_metrics in training_metrics.completed_epoch_metrics:
                # TODO(b/169197272) Validate metrics contain oracle objective
                if epoch_metrics:
                    trial.status = self.oracle.update_trial(
                        trial_id=trial.trial_id,
                        metrics=epoch_metrics,
                        step=epoch)
                    epoch += 1

            if trial.status == "STOPPED":
                google_api_client.stop_aip_training_job(
                    job_id, self._project_id)
                break

        # Ensure the training job has completed successfully.
        if not google_api_client.wait_for_aip_training_job_completion(
            job_id, self._project_id):
            raise RuntimeError(
                "AIP Training job failed, see logs for details at "
                "https://console.cloud.google.com/ai-platform/jobs/"
                "{}/charts/cpu?project={}"
                .format(job_id, self._project_id))

        # Retrieve and report any remaining metrics
        training_metrics = self._get_remote_training_metrics(
            log_reader, training_metrics.partial_epoch_metrics)

        for epoch_metrics in training_metrics.completed_epoch_metrics:
            # TODO(b/169197272) Validate metrics contain oracle objective
            # TODO(b/170907612) Support submit partial results to Oracle
            if epoch_metrics:
                self.oracle.update_trial(
                    trial_id=trial.trial_id,
                    metrics=epoch_metrics,
                    step=epoch)
                epoch += 1

        # Submit final partial epoch metrics, if any.
        if training_metrics.partial_epoch_metrics:
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=training_metrics.partial_epoch_metrics,
                step=epoch)
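
    # A minimal usage sketch (hypothetical names; `MyCloudTuner` stands in
    # for the concrete Tuner subclass defining run_trial above):
    #
    #   tuner = MyCloudTuner(
    #       hypermodel=build_model,       # assumed model-building function
    #       objective="accuracy",
    #       max_trials=10,
    #       project_id="my-project",      # assumed GCP project id
    #       region="us-central1",
    #       directory="gs://my-bucket/tuning")
    #   tuner.search(x=x_train, y=y_train, epochs=5,
    #                validation_data=(x_val, y_val))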