Exemplo n.º 1
0
 def test_is_aip_training_job_running_with_running_job(self):
     """Every non-terminal job state should be reported as running."""
     active_states = ["QUEUED", "PREPARING", "RUNNING", "CANCELLING"]
     self.mock_request.execute.side_effect = [
         {"state": state} for state in active_states]

     for state in active_states:
         is_running = google_api_client.is_aip_training_job_running(
             self._job_id, self._project_id)
         self.assertTrue(is_running, msg=None)

     # All four polls hit the same jobs.get endpoint with the same name.
     expected_name = "projects/{}/jobs/{}".format(
         self._project_id, self._job_id)
     self.mock_apiclient.projects().jobs().get.assert_called_with(
         name=expected_name)
     self.assertEqual(len(active_states), self.mock_request.execute.call_count)
Exemplo n.º 2
0
 def test_is_aip_training_job_running_with_completed_job(self):
     """Every terminal job state should be reported as not running."""
     terminal_responses = [
         {"state": "SUCCEEDED"},
         {"state": "CANCELLED"},
         {"state": "FAILED", "errorMessage": "test_error_message"},
     ]
     self.mock_request.execute.side_effect = terminal_responses

     for _ in terminal_responses:
         self.assertFalse(
             google_api_client.is_aip_training_job_running(
                 self._job_id, self._project_id))

     # Each poll queried the same fully-qualified job resource name.
     expected_name = "projects/{}/jobs/{}".format(
         self._project_id, self._job_id)
     self.mock_apiclient.projects().jobs().get.assert_called_with(
         name=expected_name)
     self.assertEqual(
         len(terminal_responses), self.mock_request.execute.call_count)
Exemplo n.º 3
0
    def run_trial(self, trial, *fit_args, **fit_kwargs):
        """Evaluates a set of hyperparameter values.

        This method is called during `search` to evaluate a set of
        hyperparameters using AI Platform training. It submits the fit as a
        remote AIP training job, then polls the job while streaming epoch
        metrics back to the oracle via TensorBoard event logs.
        Arguments:
            trial: A `Trial` instance that contains the information
              needed to run this trial. `Hyperparameters` can be accessed
              via `trial.hyperparameters`.
            *fit_args: Positional arguments passed by `search`.
            **fit_kwargs: Keyword arguments passed by `search`.
        Raises:
            RuntimeError: If AIP training job fails.
        """

        # Running the training remotely.
        # Shallow-copy so we can replace "callbacks" without mutating the
        # caller's kwargs dict.
        copied_fit_kwargs = copy.copy(fit_kwargs)

        # Handle any callbacks passed to `fit`.
        # NOTE: pop() mutates fit_kwargs, but copied_fit_kwargs (the shallow
        # copy taken above) still holds the original "callbacks" entry until
        # it is overwritten below.
        callbacks = fit_kwargs.pop("callbacks", [])
        callbacks = self._deepcopy_callbacks(callbacks)

        # Note: run_trial does not use `TunerCallback` calls, since
        # training is performed on AI Platform training remotely.

        # Handle TensorBoard/hyperparameter logging here. The TensorBoard
        # logs are used for passing metrics back from remote execution.
        self._add_logging(callbacks, trial)

        # Creating a save_model checkpoint callback with a saved model file path
        # specific to this trial. This is to prevent different trials from
        # overwriting each other.
        self._add_model_checkpoint_callback(callbacks, trial.trial_id)

        copied_fit_kwargs["callbacks"] = callbacks
        model = self.hypermodel.build(trial.hyperparameters)

        # Per-trial remote directory so artifacts of different trials do not
        # collide.
        remote_dir = os.path.join(self.directory, str(trial.trial_id))

        # Job id combines the study id and trial id so each trial maps to a
        # distinct AIP training job.
        job_id = f"{self._study_id}_{trial.trial_id}"

        # Create job spec from worker count and config
        job_spec = self._get_job_spec_from_config(job_id)

        tf.get_logger().info(
            "Calling cloud_fit with %s", {
                "model": model,
                "remote_dir": remote_dir,
                "region": self._region,
                "project_id": self._project_id,
                "image_uri": self._container_uri,
                "job_id": job_id,
                "*fit_args": fit_args,
                "job_spec": job_spec,
                "**copied_fit_kwargs": copied_fit_kwargs
            })

        # Submit the training job; this returns once the job is created, not
        # when it finishes — completion is polled below.
        cloud_fit_client.cloud_fit(model=model,
                                   remote_dir=remote_dir,
                                   region=self._region,
                                   project_id=self._project_id,
                                   image_uri=self._container_uri,
                                   job_id=job_id,
                                   job_spec=job_spec,
                                   *fit_args,
                                   **copied_fit_kwargs)

        # Create an instance of tensorboard DirectoryWatcher to retrieve the
        # logs for this trial run
        log_path = os.path.join(self._get_tensorboard_log_dir(trial.trial_id),
                                "train")

        # Tensorboard log watcher expects the path to exist
        tf.io.gfile.makedirs(log_path)

        tf.get_logger().info(
            f"Retrieving training logs for trial {trial.trial_id} from"
            f" {log_path}")
        log_reader = tf_utils.get_tensorboard_log_watcher_from_path(log_path)

        # Start with no completed epochs and no partial metrics; partial
        # metrics are threaded through each _get_remote_training_metrics call.
        training_metrics = _TrainingMetrics([], {})
        epoch = 0

        # Poll until the remote job leaves a running state (or the oracle
        # stops the trial below).
        while google_api_client.is_aip_training_job_running(
                job_id, self._project_id):

            time.sleep(_POLLING_INTERVAL_IN_SECONDS)

            # Retrieve available metrics if any
            training_metrics = self._get_remote_training_metrics(
                log_reader, training_metrics.partial_epoch_metrics)

            for epoch_metrics in training_metrics.completed_epoch_metrics:
                # TODO(b/169197272) Validate metrics contain oracle objective
                if epoch_metrics:
                    # update_trial returns the trial status, which may flip to
                    # "STOPPED" if the oracle decides to early-stop this trial.
                    trial.status = self.oracle.update_trial(
                        trial_id=trial.trial_id,
                        metrics=epoch_metrics,
                        step=epoch)
                    epoch += 1

            # Honor the oracle's early-stopping decision by cancelling the
            # remote job.
            if trial.status == "STOPPED":
                google_api_client.stop_aip_training_job(
                    job_id, self._project_id)
                break

        # Ensure the training job has completed successfully.
        if not google_api_client.wait_for_aip_training_job_completion(
                job_id, self._project_id):
            raise RuntimeError(
                "AIP Training job failed, see logs for details at "
                "https://console.cloud.google.com/ai-platform/jobs/"
                "{}/charts/cpu?project={}".format(job_id, self._project_id))

        # Retrieve and report any remaining metrics
        # (events written between the last poll and job completion).
        training_metrics = self._get_remote_training_metrics(
            log_reader, training_metrics.partial_epoch_metrics)

        for epoch_metrics in training_metrics.completed_epoch_metrics:
            # TODO(b/169197272) Validate metrics contain oracle objective
            # TODO(b/170907612) Support submit partial results to Oracle
            if epoch_metrics:
                self.oracle.update_trial(trial_id=trial.trial_id,
                                         metrics=epoch_metrics,
                                         step=epoch)
                epoch += 1

        # submit final epoch metrics
        # (the last epoch may not have been flushed as "completed").
        if training_metrics.partial_epoch_metrics:
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=training_metrics.partial_epoch_metrics,
                step=epoch)