def test_wait_for_aip_training_job_completion_non_blocking_failed(self):
    """A job whose first polled state is FAILED yields a falsy status."""
    # The mocked request always reports a failed job.
    failed_response = {
        "state": "FAILED",
        "errorMessage": "test_error_message",
    }
    self.mock_request.execute.return_value = failed_response

    job_succeeded = google_api_client.wait_for_aip_training_job_completion(
        self._job_id, self._project_id)

    self.assertFalse(job_succeeded)
    # A single status request must have been enough.
    self.mock_request.execute.assert_called_once()
def test_run_on_script(self):
    """Waits on the job of every enabled scenario and checks its status.

    Each scenario helper is called once, in declaration order, and must
    return a mapping with a "job_id" entry.
    """
    submitted_jobs = {
        # TODO(b/172668718) Enable tests after b/172668718 is resolved.
        # "auto_mirrored_strategy": self.auto_mirrored_strategy(),
        "auto_tpu_strategy": self.auto_tpu_strategy(),
        # "auto_one_device_strategy": self.auto_one_device_strategy(),
        "auto_multi_worker_strategy": self.auto_multi_worker_strategy(),
        # "none_dist_strat": self.none_dist_strat(),
        # "docker_config_cloud_build": self.docker_config_cloud_build(),
        "docker_config_parent_img": self.docker_config_parent_img(),
        # "docker_config_image": self.docker_config_image(),
        # "docker_config_cache_from": self.docker_config_cache_from(),
        # "job_labels": self.job_labels(),
        "cloud_build_base_image_backward_compatibility":
            self.cloud_build_base_image_backward_compatibility(),
    }

    for scenario, launch_result in submitted_jobs.items():
        job_id = launch_result["job_id"]
        job_succeeded = (
            google_api_client.wait_for_aip_training_job_completion(
                job_id, _PROJECT_ID))
        failure_message = (
            "Job {} generated from the test: {} has failed".format(
                job_id, scenario))
        self.assertTrue(job_succeeded, failure_message)
def test_wait_for_aip_training_job_completion_multiple_checks_failed(self):
    """Status is falsy when the job fails after two non-final states."""
    # One response per status request, consumed in order by the mock.
    polled_responses = [
        {"state": "PREPARING"},
        {"state": "RUNNING"},
        {"state": "FAILED", "errorMessage": "test_error_message"},
    ]
    self.mock_request.execute.side_effect = polled_responses

    job_succeeded = google_api_client.wait_for_aip_training_job_completion(
        self._job_id, self._project_id)

    self.assertFalse(job_succeeded)
    # Every queued response must have been requested exactly once.
    self.assertEqual(
        len(polled_responses), self.mock_request.execute.call_count)
def test_wait_for_aip_training_job_completion_multiple_checks_success(self):
    """Status is truthy when the job succeeds after two non-final states."""
    # One response per status request, consumed in order by the mock.
    polled_responses = [
        {"state": "PREPARING"},
        {"state": "RUNNING"},
        {"state": "SUCCEEDED"},
    ]
    self.mock_request.execute.side_effect = polled_responses

    job_succeeded = google_api_client.wait_for_aip_training_job_completion(
        self._job_id, self._project_id)

    self.assertTrue(job_succeeded)
    # Every queued response must have been requested exactly once.
    self.assertEqual(
        len(polled_responses), self.mock_request.execute.call_count)
def test_wait_for_aip_training_job_completion_non_blocking_cancelled(self):
    """A job whose first polled state is CANCELLED yields a truthy status."""
    # Arrange: the mocked request always reports a cancelled job.
    expected_job_name = "projects/{}/jobs/{}".format(
        self._project_id, self._job_id)
    self.mock_request.execute.return_value = {"state": "CANCELLED"}

    # Act.
    job_status = google_api_client.wait_for_aip_training_job_completion(
        self._job_id, self._project_id)

    # Assert: truthy status, one status request, addressed by full job name.
    self.assertTrue(job_status)
    self.mock_request.execute.assert_called_once()
    jobs_get = self.mock_apiclient.projects().jobs().get
    jobs_get.assert_called_with(name=expected_job_name)
def test_run_on_notebook(self):
    """Waits on the job of every notebook scenario and checks its status.

    Each scenario helper is called once, in the order listed, and must
    return a mapping with a "job_id" entry.
    """
    # Every scenario is keyed by the name of the helper method that runs it.
    scenario_names = (
        "auto_mirrored_strategy",
        "auto_tpu_strategy",
        "auto_one_device_strategy",
        "auto_multi_worker_strategy",
        "docker_config_cloud_build",
        "docker_config_parent_img",
        "docker_config_image",
        "docker_config_cache_from",
    )
    submitted_jobs = {
        name: getattr(self, name)() for name in scenario_names
    }

    for scenario, launch_result in submitted_jobs.items():
        job_id = launch_result["job_id"]
        job_succeeded = (
            google_api_client.wait_for_aip_training_job_completion(
                job_id, _PROJECT_ID))
        failure_message = (
            "Job {} generated from the test: {} has failed".format(
                job_id, scenario))
        self.assertTrue(job_succeeded, failure_message)
def test_in_memory_data(self):
    """Runs `cloud_fit` on in-memory NumPy data and evaluates the result.

    The test submits a training job, waits for it to complete, loads the
    model saved under "<remote_dir>/checkpoint" and checks its metrics.
    Under TF 1.x the test is reported as skipped instead of silently
    passing.
    """
    # This test should only run in tf 2.x
    if utils.is_tf_v1():
        self.skipTest("test_in_memory_data only runs in tf 2.x")

    # Create a folder under remote dir for this test's data
    tmp_folder = str(uuid.uuid4())
    remote_dir = os.path.join(self._remote_dir, tmp_folder)

    # Keep track of test folders created for final clean up
    self._test_folders.append(remote_dir)

    x = np.random.random((2, 3))
    y = np.random.randint(0, 2, (2, 2))

    job_id = client.cloud_fit(
        self._model(),
        x=x,
        y=y,
        remote_dir=remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._image_uri,
        job_id="cloud_fit_e2e_test_{}_{}".format(
            _BUILD_ID.replace("-", "_"), "test_in_memory_data"),
        epochs=2,
    )
    logging.info("test_in_memory_data submitted with job id: %s", job_id)

    # Wait for AIP Training job to finish successfully
    self.assertTrue(
        google_api_client.wait_for_aip_training_job_completion(
            job_id, self._project_id))

    # load model from remote dir
    trained_model = tf.keras.models.load_model(
        os.path.join(remote_dir, "checkpoint"))
    eval_results = trained_model.evaluate(x, y)

    # Accuracy should be better than zero
    self.assertListEqual(trained_model.metrics_names, ["loss", "accuracy"])
    self.assertGreater(eval_results[1], 0)
def run_trial(self, trial, *fit_args, **fit_kwargs):
    """Evaluates a set of hyperparameter values.

    This method is called during `search` to evaluate a set of
    hyperparameters using AI Platform training. It submits the training
    job through `cloud_fit`, polls the job while it is running, and
    reports the epoch metrics read from the trial's TensorBoard logs to
    the oracle.

    Arguments:
        trial: A `Trial` instance that contains the information needed to
            run this trial. `Hyperparameters` can be accessed via
            `trial.hyperparameters`.
        *fit_args: Positional arguments passed by `search`.
        **fit_kwargs: Keyword arguments passed by `search`.

    Raises:
        RuntimeError: If AIP training job fails.
    """

    # Running the training remotely. Work on a shallow copy of the keyword
    # arguments; its "callbacks" entry is replaced further down.
    copied_fit_kwargs = copy.copy(fit_kwargs)

    # Handle any callbacks passed to `fit`.
    callbacks = fit_kwargs.pop("callbacks", [])
    callbacks = self._deepcopy_callbacks(callbacks)

    # Note: run_trial does not use `TunerCallback` calls, since
    # training is performed on AI Platform training remotely.

    # Handle TensorBoard/hyperparameter logging here. The TensorBoard
    # logs are used for passing metrics back from remote execution.
    self._add_logging(callbacks, trial)

    # Creating a save_model checkpoint callback with a saved model file path
    # specific to this trial. This is to prevent different trials from
    # overwriting each other.
    self._add_model_checkpoint_callback(callbacks, trial.trial_id)

    copied_fit_kwargs["callbacks"] = callbacks
    model = self.hypermodel.build(trial.hyperparameters)

    # Both the remote directory and the job id are derived from the trial
    # id, so each trial gets its own of each.
    remote_dir = os.path.join(self.directory, str(trial.trial_id))

    job_id = f"{self._study_id}_{trial.trial_id}"

    # Create job spec from worker count and config
    job_spec = self._get_job_spec_from_config(job_id)

    tf.get_logger().info(
        "Calling cloud_fit with %s", {
            "model": model,
            "remote_dir": remote_dir,
            "region": self._region,
            "project_id": self._project_id,
            "image_uri": self._container_uri,
            "job_id": job_id,
            "*fit_args": fit_args,
            "job_spec": job_spec,
            "**copied_fit_kwargs": copied_fit_kwargs
        })

    # NOTE(review): `*fit_args` is written after the keyword arguments but
    # is still bound positionally first. If `model` is `cloud_fit`'s first
    # positional parameter, any non-empty `fit_args` would collide with
    # `model=` -- confirm that `search` never passes positional fit args.
    cloud_fit_client.cloud_fit(model=model,
                               remote_dir=remote_dir,
                               region=self._region,
                               project_id=self._project_id,
                               image_uri=self._container_uri,
                               job_id=job_id,
                               job_spec=job_spec,
                               *fit_args,
                               **copied_fit_kwargs)

    # Create an instance of tensorboard DirectoryWatcher to retrieve the
    # logs for this trial run
    log_path = os.path.join(self._get_tensorboard_log_dir(trial.trial_id),
                            "train")

    # Tensorboard log watcher expects the path to exist
    tf.io.gfile.makedirs(log_path)

    tf.get_logger().info(
        f"Retrieving training logs for trial {trial.trial_id} from"
        f" {log_path}")
    log_reader = tf_utils.get_tensorboard_log_watcher_from_path(log_path)

    # `epoch` is the `step` passed with the next metrics report; it is
    # incremented after each completed-epoch report sent to the oracle.
    training_metrics = _TrainingMetrics([], {})
    epoch = 0

    # Poll the job while it is running, forwarding metrics as they appear.
    while google_api_client.is_aip_training_job_running(
            job_id, self._project_id):

        time.sleep(_POLLING_INTERVAL_IN_SECONDS)

        # Retrieve available metrics if any. The partial metrics of the
        # previous read are handed back to the reader.
        training_metrics = self._get_remote_training_metrics(
            log_reader, training_metrics.partial_epoch_metrics)

        for epoch_metrics in training_metrics.completed_epoch_metrics:
            # TODO(b/169197272) Validate metrics contain oracle objective
            if epoch_metrics:
                trial.status = self.oracle.update_trial(
                    trial_id=trial.trial_id,
                    metrics=epoch_metrics,
                    step=epoch)
                epoch += 1

        # The oracle asked to stop this trial: stop the remote job and
        # leave the polling loop.
        if trial.status == "STOPPED":
            google_api_client.stop_aip_training_job(
                job_id, self._project_id)
            break

    # Ensure the training job has completed successfully.
    if not google_api_client.wait_for_aip_training_job_completion(
            job_id, self._project_id):
        raise RuntimeError(
            "AIP Training job failed, see logs for details at "
            "https://console.cloud.google.com/ai-platform/jobs/"
            "{}/charts/cpu?project={}".format(job_id, self._project_id))

    # Retrieve and report any remaining metrics
    training_metrics = self._get_remote_training_metrics(
        log_reader, training_metrics.partial_epoch_metrics)

    for epoch_metrics in training_metrics.completed_epoch_metrics:
        # TODO(b/169197272) Validate metrics contain oracle objective
        # TODO(b/170907612) Support submit partial results to Oracle
        if epoch_metrics:
            self.oracle.update_trial(trial_id=trial.trial_id,
                                     metrics=epoch_metrics,
                                     step=epoch)
            epoch += 1

    # submit final epoch metrics
    if training_metrics.partial_epoch_metrics:
        self.oracle.update_trial(
            trial_id=trial.trial_id,
            metrics=training_metrics.partial_epoch_metrics,
            step=epoch)