Code example #1
    def test_run_on_script(self):
        track_status = {
            "auto_mirrored_strategy": self.auto_mirrored_strategy(),
            "auto_tpu_strategy": self.auto_tpu_strategy(),
            "auto_one_device_strategy": self.auto_one_device_strategy(),
            "auto_one_device_strategy_cloud_build":
                self.auto_one_device_strategy_cloud_build(),
            "auto_multi_worker_strategy": self.auto_multi_worker_strategy(),
            "none_dist_strat_multi_worker_strategy":
                self.none_dist_strat_multi_worker_strategy(),
            "auto_dist_strat_mwms_custom_img":
                self.auto_dist_strat_mwms_custom_img(),
            "auto_one_device_job_labels": self.auto_one_device_job_labels(),
        }

        for test_name, job_id in track_status.items():
            self.assertTrue(
                google_api_client.wait_for_api_training_job_success(
                    job_id, _PROJECT_ID),
                "Job {} generated from the test: {} has failed".format(
                    job_id, test_name))
Code example #2
    def test_wait_for_api_training_job_success_multiple_checks_success(self):
        self.mock_request.execute.side_effect = [
            {"state": "PREPARING"},
            {"state": "RUNNING"},
            {"state": "SUCCEEDED"},
        ]
        status = google_api_client.wait_for_api_training_job_success(
            self._job_id, self._project_id)
        self.assertTrue(status)
        self.assertEqual(3, self.mock_request.execute.call_count)
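
The tests in examples #2 through #5 script `self.mock_request.execute` and assert on `mock_log_error`, which implies a polling loop along these lines. This is a minimal sketch, not the library's actual source: the `ml`/`v1` discovery service, the poll interval, and the use of `logging.error` are assumptions; the job-name format comes from the assertion in example #4 and the error-message format from the assertions in examples #3 and #5.

import logging
import time

from googleapiclient import discovery

_POLL_INTERVAL_IN_SECONDS = 30  # Assumed; the real interval is not shown.


def wait_for_api_training_job_success(job_id, project_id):
    """Blocks until the AIP Training job ends; returns True on success."""
    # Job name format taken from the assertion in example #4.
    job_name = "projects/{}/jobs/{}".format(project_id, job_id)
    api_client = discovery.build("ml", "v1")  # Assumed service name/version.
    request = api_client.projects().jobs().get(name=job_name)
    while True:
        response = request.execute()
        state = response["state"]
        if state == "SUCCEEDED":
            return True
        if state == "FAILED":
            # Message format matches the mock_log_error assertions.
            logging.error("AIP Training job %s failed with error %s.",
                          job_id, response.get("errorMessage", ""))
            return False
        time.sleep(_POLL_INTERVAL_IN_SECONDS)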
Code example #3
    def test_wait_for_api_training_job_success_non_blocking_failed(
            self, mock_log_error):
        self.mock_request.execute.return_value = {
            "state": "FAILED",
            "errorMessage": "test_error_message",
        }
        status = google_api_client.wait_for_api_training_job_success(
            self._job_id, self._project_id)
        self.assertFalse(status)
        self.mock_request.execute.assert_called_once()
        mock_log_error.assert_called_once_with(
            "AIP Training job %s failed with error %s.", self._job_id,
            "test_error_message")
Code example #4
    def test_wait_for_api_training_job_success_non_blocking_success(
            self, mock_log_error):
        self.mock_request.execute.return_value = {"state": "SUCCEEDED"}
        status = google_api_client.wait_for_api_training_job_success(
            self._job_id, self._project_id)
        self.assertTrue(status)
        self.mock_request.execute.assert_called_once()
        job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id)
        self.mock_apiclient.projects().jobs().get.assert_called_with(
            name=job_name)
        mock_log_error.assert_not_called()
Code example #5
    def test_wait_for_api_training_job_success_multiple_checks_failed(
            self, mock_log_error):
        self.mock_request.execute.side_effect = [
            {"state": "PREPARING"},
            {"state": "RUNNING"},
            {"state": "FAILED", "errorMessage": "test_error_message"},
        ]
        status = google_api_client.wait_for_api_training_job_success(
            self._job_id, self._project_id)
        self.assertFalse(status)
        self.assertEqual(3, self.mock_request.execute.call_count)
        mock_log_error.assert_called_once_with(
            "AIP Training job %s failed with error %s.", self._job_id,
            "test_error_message")
Code example #6
    def test_in_memory_data(self):
        # This test should only run in tf 2.x
        if utils.is_tf_v1():
            return

        # Create a folder under remote dir for this test's data
        tmp_folder = str(uuid.uuid4())
        remote_dir = os.path.join(self._remote_dir, tmp_folder)

        # Keep track of test folders created for final clean up
        self._test_folders.append(remote_dir)

        x = np.random.random((2, 3))
        y = np.random.randint(0, 2, (2, 2))

        job_id = client.cloud_fit(
            self._model(),
            x=x,
            y=y,
            remote_dir=remote_dir,
            region=self._region,
            project_id=self._project_id,
            image_uri=self._image_uri,
            job_id="cloud_fit_e2e_test_{}_{}".format(
                _BUILD_ID.replace("-", "_"), "test_in_memory_data"),
            epochs=2,
        )

        # TODO(b/169297404) Replace AIP job status logic with utils wrapper
        # Wait for AIP Training job to finish successfully
        self.assertTrue(
            google_api_client.wait_for_api_training_job_success(
                job_id, self._project_id))

        # Load the trained model from the remote directory
        trained_model = tf.keras.models.load_model(
            os.path.join(remote_dir, "output"))
        eval_results = trained_model.evaluate(x, y)

        # Accuracy should be better than zero
        self.assertListEqual(trained_model.metrics_names, ["loss", "accuracy"])
        self.assertGreater(eval_results[1], 0)
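
Example #6 depends on a `self._model()` helper that is not shown. A hypothetical version consistent with the shapes of `x` (2, 3) and `y` (2, 2), and with the asserted `metrics_names` of `["loss", "accuracy"]`, could look like this; the layer sizes, activations, and optimizer are illustrative only.

import tensorflow as tf


def _model(self):
    # Input shape (3,) matches x; two outputs match y (sizes assumed).
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation="relu", input_shape=(3,)),
        tf.keras.layers.Dense(2, activation="softmax"),
    ])
    # Compiling with an "accuracy" metric yields metrics_names of
    # ["loss", "accuracy"], as the test asserts.
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"])
    return model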
Code example #7
    def run_trial(self, trial, *fit_args, **fit_kwargs):
        """Evaluates a set of hyperparameter values.

        This method is called during `search` to evaluate a set of
        hyperparameters using AI Platform training.
        Arguments:
            trial: A `Trial` instance that contains the information
              needed to run this trial. `Hyperparameters` can be accessed
              via `trial.hyperparameters`.
            *fit_args: Positional arguments passed by `search`.
            **fit_kwargs: Keyword arguments passed by `search`.
        Raises:
            RuntimeError: If AIP training job fails.
        """

        # Running the training remotely.
        copied_fit_kwargs = copy.copy(fit_kwargs)

        # Handle any callbacks passed to `fit`.
        callbacks = fit_kwargs.pop("callbacks", [])
        callbacks = self._deepcopy_callbacks(callbacks)

        # Note that `run_trial` does not use `TunerCallback`, since
        # training is performed remotely on AI Platform Training.

        # Create a TensorBoard callback with a log dir path specific to this
        # trial_id. The TensorBoard logs are used to pass metrics back from
        # the remote execution.
        self._add_tensorboard_callback(callbacks, trial.trial_id)

        # Create a model checkpoint callback with a saved-model file path
        # specific to this trial, so that different trials do not overwrite
        # each other.
        self._add_model_checkpoint_callback(callbacks, trial.trial_id)

        copied_fit_kwargs["callbacks"] = callbacks
        model = self.hypermodel.build(trial.hyperparameters)
        job_id = "{}_{}".format(self._study_id, trial.trial_id)

        tf.get_logger().info("Calling cloud_fit with %s", {
            "model": model,
            "remote_dir": self.directory,
            "region": self._region,
            "project_id": self._project_id,
            "image_uri": self.container_uri,
            "job_id": job_id,
            "*fit_args": fit_args,
            "**copied_fit_kwargs": copied_fit_kwargs})

        client.cloud_fit(
            model=model,
            remote_dir=self.directory,
            region=self._region,
            project_id=self._project_id,
            image_uri=self.container_uri,
            job_id=job_id,
            *fit_args,
            **copied_fit_kwargs)

        # TODO(b/167569957) Add support for early termination.
        if not google_api_client.wait_for_api_training_job_success(
            job_id, self._project_id):
            raise RuntimeError(
                "AIP Training job failed, see logs for details at https://console.cloud.google.com/ai-platform/jobs/{}/charts/cpu?project={}"  # pylint: disable=line-too-long
                .format(job_id, self._project_id))

        # If the job was successful, retrieve the metrics
        training_metrics = self._get_remote_training_metrics(trial.trial_id)

        # Note that since all job results are submitted in one shot, this may
        # exceed the AI Platform Vizier limit of 1000 RPS. For more details
        # on API quotas refer to:
        # https://cloud.google.com/ai-platform/optimizer/docs/overview
        for epoch, epoch_metrics in enumerate(training_metrics):
            # TODO(b/169197272) Validate metrics contain oracle objective
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=epoch_metrics,
                step=epoch)
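
The `_add_tensorboard_callback` and `_add_model_checkpoint_callback` helpers invoked above are not shown. Hypothetical versions consistent with the surrounding comments might look like the following; the directory layout is an assumption, and the only essential property is that both paths are keyed by `trial_id` so trials cannot overwrite each other.

import os

import tensorflow as tf


def _add_tensorboard_callback(self, callbacks, trial_id):
    # Log dir keyed by trial_id; these logs carry metrics back from the
    # remote job (layout assumed).
    callbacks.append(tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(self.directory, "logs", str(trial_id))))


def _add_model_checkpoint_callback(self, callbacks, trial_id):
    # Checkpoint path keyed by trial_id so trials do not collide
    # (layout assumed).
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(self.directory, str(trial_id), "checkpoint")))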