def test_job_id(self, mock_serialize_assets, mock_submit_job):
    """Verifies that a caller-supplied `job_id` is forwarded in the job request.

    On TF 1.x, `cloud_fit` is expected to raise `RuntimeError` instead.
    """
    # TF 1.x is not supported
    if utils.is_tf_v1():
        with self.assertRaises(RuntimeError):
            client.cloud_fit(
                self._model,
                x=self._dataset,
                validation_data=self._dataset,
                remote_dir=self._remote_dir,
                job_spec=self._job_spec,
                batch_size=1,
                epochs=2,
                verbose=3,
            )
        return

    test_job_id = "test_job_id"
    client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        job_spec=self._job_spec,
        job_id=test_job_id,
        batch_size=1,
        epochs=2,
        verbose=3,
    )

    # The mocked submit call is expected to receive exactly two positional
    # arguments; the first one is the job request body.
    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    # `assertDictContainsSubset` is deprecated (removed in Python 3.12), so
    # check key presence and value explicitly.
    self.assertIn("job_id", body)
    self.assertEqual(body["job_id"], test_job_id)
def test_distribution_strategy(self, mock_serialize_assets, mock_submit_job):
    """Verifies how `distribution_strategy` is forwarded to the training job.

    Checks the default strategy, an explicit `MIRRORED_STRATEGY_NAME`, and
    that an unknown strategy name raises `ValueError`. On TF 1.x,
    `cloud_fit` is expected to raise `RuntimeError` instead.
    """
    # TF 1.x is not supported
    if utils.is_tf_v1():
        with self.assertRaises(RuntimeError):
            client.cloud_fit(
                self._model, x=self._dataset, remote_dir=self._remote_dir)
        return

    def assert_last_training_args(expected_args):
        # Reads the most recent mocked submit call, whose first positional
        # argument is the job request body, and checks its training args.
        # `assertDictContainsSubset` is deprecated (removed in Python 3.12),
        # so key presence and value are checked explicitly.
        kargs, _ = mock_submit_job.call_args
        body, _ = kargs
        training_input = body["trainingInput"]
        self.assertIn("args", training_input)
        self.assertEqual(training_input["args"], expected_args)

    # No strategy given: the multi-worker mirrored strategy is expected.
    client.cloud_fit(self._model, x=self._dataset, remote_dir=self._remote_dir)
    assert_last_training_args([
        "--remote_dir",
        self._remote_dir,
        "--distribution_strategy",
        MULTI_WORKER_MIRRORED_STRATEGY_NAME,
    ])

    # An explicit strategy is passed through unchanged.
    client.cloud_fit(
        self._model,
        x=self._dataset,
        remote_dir=self._remote_dir,
        distribution_strategy=MIRRORED_STRATEGY_NAME,
        job_spec=self._job_spec,
    )
    assert_last_training_args([
        "--remote_dir",
        self._remote_dir,
        "--distribution_strategy",
        MIRRORED_STRATEGY_NAME,
    ])

    # Unknown strategy names are rejected.
    with self.assertRaises(ValueError):
        client.cloud_fit(
            self._model,
            x=self._dataset,
            remote_dir=self._remote_dir,
            distribution_strategy="not_implemented_strategy",
            job_spec=self._job_spec,
        )
def test_custom_job_spec(self, mock_submit_job):
    """Verifies that a custom `job_spec` is reflected in the submitted job.

    Expects `trainingInput.masterConfig` to be exactly
    `{"imageUri": self._image_uri}` (presumably populated from
    `self._job_spec` by the fixture; confirm) and the default training args.
    On TF 1.x, `cloud_fit` is expected to raise `RuntimeError` instead.
    """
    # TF 1.x is not supported
    if utils.is_tf_v1():
        with self.assertRaises(RuntimeError):
            client.cloud_fit(
                self._model,
                x=self._dataset,
                validation_data=self._dataset,
                remote_dir=self._remote_dir,
                job_spec=self._job_spec,
                batch_size=1,
                epochs=2,
                verbose=3,
            )
        return

    client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        job_spec=self._job_spec,
        batch_size=1,
        epochs=2,
        verbose=3,
    )

    # The first positional argument of the mocked submit call is the body.
    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    training_input = body["trainingInput"]

    expected = {
        "masterConfig": {
            "imageUri": self._image_uri,
        },
        "args": [
            "--remote_dir",
            self._remote_dir,
            "--distribution_strategy",
            MULTI_WORKER_MIRRORED_STRATEGY_NAME,
        ],
    }
    # `assertDictContainsSubset` is deprecated (removed in Python 3.12).
    # Project `training_input` onto the expected keys; a missing key shows up
    # as a dict mismatch with a readable diff.
    actual = {
        key: training_input[key] for key in expected if key in training_input
    }
    self.assertEqual(actual, expected)
def test_client_with_tf_1x_raises_error(self):
    """Checks that `cloud_fit` refuses to run under TF 1.x.

    Under TF 2.x this test does nothing and returns immediately.
    """
    if utils.is_tf_v1():
        features = np.random.random((2, 3))
        labels = np.random.randint(0, 2, (2, 2))
        # TF 1.x is unsupported, so a RuntimeError must be raised.
        with self.assertRaises(RuntimeError):
            client.cloud_fit(
                self._model(),
                x=features,
                y=labels,
                remote_dir="gs://some_test_dir",
                region=self._region,
                project_id=self._project_id,
                image_uri=self._image_uri,
                epochs=2,
            )
def test_fit_kwargs(self, mock_submit_job):
    """Verifies that scalar `fit` kwargs are serialized into training assets.

    After `cloud_fit`, loads `<remote_dir>/training_assets` and checks that
    the kwargs returned by its `fit_kwargs_fn` include `batch_size`, `epochs`
    and `verbose` with the values passed in. On TF 1.x, `cloud_fit` is
    expected to raise `RuntimeError` instead.
    """
    # TF 1.x is not supported
    if utils.is_tf_v1():
        with self.assertRaises(RuntimeError):
            client.cloud_fit(
                self._model,
                x=self._dataset,
                validation_data=self._dataset,
                remote_dir=self._remote_dir,
                job_spec=self._job_spec,
                batch_size=1,
                epochs=2,
                verbose=3,
            )
        return

    job_id = client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._image_uri,
        batch_size=1,
        epochs=2,
        verbose=3,
    )

    # The first positional argument of the mocked submit call is the body.
    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self.assertEqual(body["job_id"], job_id)

    # args are ["--remote_dir", <dir>, ...], so index 1 is the remote dir.
    remote_dir = body["trainingInput"]["args"][1]
    training_assets_graph = tf.saved_model.load(
        os.path.join(remote_dir, "training_assets"))
    fit_kwargs = tfds.as_numpy(training_assets_graph.fit_kwargs_fn())

    # The expected values must be a subset of the serialized kwargs, not the
    # other way around. The previous call passed the arguments of
    # `assertDictContainsSubset(subset, dictionary)` in swapped order. That
    # API is also removed in Python 3.12, so project and compare instead.
    expected = {"batch_size": 1, "epochs": 2, "verbose": 3}
    actual = {key: fit_kwargs[key] for key in expected if key in fit_kwargs}
    self.assertEqual(actual, expected)
def test_serialize_assets(self):
    """Verifies `client._serialize_assets` writes assets, model and callbacks.

    Checks that `training_assets` and `model` sub-directories are non-empty,
    and that the pickled callbacks round-trip to a `TensorBoard` callback.
    On TF 1.x, `cloud_fit` is expected to raise `RuntimeError` instead.
    """
    # TF 1.x is not supported
    if utils.is_tf_v1():
        with self.assertRaises(RuntimeError):
            client.cloud_fit(
                self._model,
                x=self._dataset,
                validation_data=self._dataset,
                remote_dir=self._remote_dir,
                job_spec=self._job_spec,
                batch_size=1,
                epochs=2,
                verbose=3,
            )
        return

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=self._remote_dir)
    # Work on a copy so that adding `callbacks` does not mutate the shared
    # `self._scalar_fit_kwargs` fixture in place.
    args = dict(self._scalar_fit_kwargs)
    args["callbacks"] = [tensorboard_callback]
    client._serialize_assets(self._remote_dir, self._model, **args)

    self.assertGreaterEqual(
        len(
            tf.io.gfile.listdir(
                os.path.join(self._remote_dir, "training_assets"))),
        1,
    )
    self.assertGreaterEqual(
        len(tf.io.gfile.listdir(os.path.join(self._remote_dir, "model"))), 1)

    training_assets_graph = tf.saved_model.load(
        os.path.join(self._remote_dir, "training_assets"))
    # The callbacks are stored pickled; unpickle them and check the type.
    pickled_callbacks = tfds.as_numpy(training_assets_graph.callbacks_fn())
    unpickled_callbacks = cloudpickle.loads(pickled_callbacks)
    self.assertIsInstance(unpickled_callbacks[0],
                          tf.keras.callbacks.TensorBoard)
def test_in_memory_data(self):
    """End-to-end check that `cloud_fit` can train on in-memory numpy data.

    Submits a job through `cloud_fit` and waits for it via
    `google_api_client`. It then loads the model saved under
    `<remote_dir>/output`, evaluates it, and checks its metrics.
    Only runs under TF 2.x.
    """
    if utils.is_tf_v1():
        return  # This test should only run in tf 2.x.

    # Give this test its own folder under the remote dir, and remember it so
    # it is removed during the final clean up.
    remote_dir = os.path.join(self._remote_dir, str(uuid.uuid4()))
    self._test_folders.append(remote_dir)

    features = np.random.random((2, 3))
    labels = np.random.randint(0, 2, (2, 2))

    safe_build_id = _BUILD_ID.replace("-", "_")
    requested_job_id = "cloud_fit_e2e_test_{}_{}".format(
        safe_build_id, "test_in_memory_data")

    submitted_job_id = client.cloud_fit(
        self._model(),
        x=features,
        y=labels,
        remote_dir=remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._image_uri,
        job_id=requested_job_id,
        epochs=2,
    )

    # TODO(b/169297404) Replace AIP job status logic with utils wrapper
    job_finished_ok = google_api_client.wait_for_api_training_job_completion(
        submitted_job_id, self._project_id)
    self.assertTrue(job_finished_ok)

    output_dir = os.path.join(remote_dir, "output")
    trained_model = tf.keras.models.load_model(output_dir)
    eval_results = trained_model.evaluate(features, labels)

    self.assertListEqual(trained_model.metrics_names, ["loss", "accuracy"])
    # eval_results[1] is accuracy; it should be better than zero.
    self.assertGreater(eval_results[1], 0)
def run_trial(self, trial, *fit_args, **fit_kwargs):
    """Evaluates a set of hyperparameter values.

    This method is called during `search` to evaluate a set of
    hyperparameters using AI Platform training.

    Arguments:
        trial: A `Trial` instance that contains the information needed to
            run this trial. `Hyperparameters` can be accessed via
            `trial.hyperparameters`.
        *fit_args: Positional arguments passed by `search`. They are mapped
            onto the positional parameters of `tf.keras.Model.fit` and
            forwarded to `cloud_fit` as keyword arguments.
        **fit_kwargs: Keyword arguments passed by `search`.

    Raises:
        RuntimeError: If AIP training job fails.
        TypeError: If more positional `fit_args` are given than
            `tf.keras.Model.fit` accepts, or one of them duplicates a key in
            `fit_kwargs`.
    """
    if fit_args:
        # `cloud_fit` takes `model` as its first positional parameter, so
        # forwarding `*fit_args` positionally would bind `fit_args[0]` to
        # `model` and fail with "got multiple values for argument 'model'".
        # Re-key them by the positional parameters of `tf.keras.Model.fit`.
        import inspect  # pylint: disable=g-import-not-at-top
        fit_param_names = [
            name
            for name, param in inspect.signature(
                tf.keras.Model.fit).parameters.items()
            if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        ][1:]  # Skip `self`.
        if len(fit_args) > len(fit_param_names):
            raise TypeError(
                "`fit` takes at most {} positional arguments but {} were "
                "given.".format(len(fit_param_names), len(fit_args)))
        for name, value in zip(fit_param_names, fit_args):
            if name in fit_kwargs:
                raise TypeError(
                    "`fit` got multiple values for argument '{}'.".format(
                        name))
            fit_kwargs[name] = value

    # Running the training remotely.
    copied_fit_kwargs = copy.copy(fit_kwargs)

    # Handle any callbacks passed to `fit`.
    callbacks = fit_kwargs.pop("callbacks", [])
    callbacks = self._deepcopy_callbacks(callbacks)

    # Note run_trial does not use `TunerCallback` calls, since
    # training is performed on AI Platform training remotely.

    # Creating a tensorboard callback with log-dir path specific for this
    # trial_id. The tensorboard logs are used for passing metrics back from
    # remote execution.
    self._add_tensorboard_callback(callbacks, trial.trial_id)

    # Creating a save_model checkpoint callback with a saved model file path
    # specific to this trial, this is to prevent different trials from
    # overwriting each other.
    self._add_model_checkpoint_callback(callbacks, trial.trial_id)

    copied_fit_kwargs["callbacks"] = callbacks
    model = self.hypermodel.build(trial.hyperparameters)

    job_id = "{}_{}".format(self._study_id, trial.trial_id)
    tf.get_logger().info("Calling cloud_fit with %s", {
        "model": model,
        "remote_dir": self.directory,
        "region": self._region,
        "project_id": self._project_id,
        "image_uri": self.container_uri,
        "job_id": job_id,
        "*fit_args": fit_args,
        "**copied_fit_kwargs": copied_fit_kwargs})

    # Positional `fit` arguments were folded into `copied_fit_kwargs` above.
    client.cloud_fit(
        model=model,
        remote_dir=self.directory,
        region=self._region,
        project_id=self._project_id,
        image_uri=self.container_uri,
        job_id=job_id,
        **copied_fit_kwargs)

    # TODO(b/167569957) Add support for early termination.
    if not google_api_client.wait_for_api_training_job_success(
            job_id, self._project_id):
        raise RuntimeError(
            "AIP Training job failed, see logs for details at https://console.cloud.google.com/ai-platform/jobs/{}/charts/cpu?project={}"  # pylint: disable=line-too-long
            .format(job_id, self._project_id))

    # If the job was successful, retrieve the metrics
    training_metrics = self._get_remote_training_metrics(trial.trial_id)
    # Note since we are submitting all job results in one shot, this may
    # result in going over AI Platform Vizier limit of 1000 RPS. For more
    # details on API quotas refer to:
    # https://cloud.google.com/ai-platform/optimizer/docs/overview
    for epoch, epoch_metrics in enumerate(training_metrics):
        # TODO(b/169197272) Validate metrics contain oracle objective
        self.oracle.update_trial(
            trial_id=trial.trial_id, metrics=epoch_metrics, step=epoch)
def run_trial(self, trial, *fit_args, **fit_kwargs):
    """Evaluates a set of hyperparameter values.

    This method is called during `search` to evaluate a set of
    hyperparameters using AI Platform training. While the remote job is
    running, completed-epoch metrics are read from this trial's TensorBoard
    logs and reported to the oracle. If the oracle returns "STOPPED", the
    remote job is stopped.

    Arguments:
        trial: A `Trial` instance that contains the information needed to
            run this trial. `Hyperparameters` can be accessed via
            `trial.hyperparameters`.
        *fit_args: Positional arguments passed by `search`. They are mapped
            onto the positional parameters of `tf.keras.Model.fit` and
            forwarded to `cloud_fit` as keyword arguments.
        **fit_kwargs: Keyword arguments passed by `search`.

    Raises:
        RuntimeError: If AIP training job fails.
        TypeError: If more positional `fit_args` are given than
            `tf.keras.Model.fit` accepts, or one of them duplicates a key in
            `fit_kwargs`.
    """
    if fit_args:
        # `cloud_fit` takes `model` as its first positional parameter, so
        # forwarding `*fit_args` positionally would bind `fit_args[0]` to
        # `model` and fail with "got multiple values for argument 'model'".
        # Re-key them by the positional parameters of `tf.keras.Model.fit`.
        import inspect  # pylint: disable=g-import-not-at-top
        fit_param_names = [
            name
            for name, param in inspect.signature(
                tf.keras.Model.fit).parameters.items()
            if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        ][1:]  # Skip `self`.
        if len(fit_args) > len(fit_param_names):
            raise TypeError(
                "`fit` takes at most {} positional arguments but {} were "
                "given.".format(len(fit_param_names), len(fit_args)))
        for name, value in zip(fit_param_names, fit_args):
            if name in fit_kwargs:
                raise TypeError(
                    "`fit` got multiple values for argument '{}'.".format(
                        name))
            fit_kwargs[name] = value

    # Running the training remotely.
    copied_fit_kwargs = copy.copy(fit_kwargs)

    # Handle any callbacks passed to `fit`.
    callbacks = fit_kwargs.pop("callbacks", [])
    callbacks = self._deepcopy_callbacks(callbacks)

    # Note run_trial does not use `TunerCallback` calls, since
    # training is performed on AI Platform training remotely.

    # Creating a tensorboard callback with log-dir path specific for this
    # trial_id. The tensorboard logs are used for passing metrics back from
    # remote execution.
    self._add_tensorboard_callback(callbacks, trial.trial_id)

    # Creating a save_model checkpoint callback with a saved model file path
    # specific to this trial, this is to prevent different trials from
    # overwriting each other.
    self._add_model_checkpoint_callback(callbacks, trial.trial_id)

    copied_fit_kwargs["callbacks"] = callbacks
    model = self.hypermodel.build(trial.hyperparameters)

    remote_dir = os.path.join(self.directory, str(trial.trial_id))

    # TODO(b/170687807) Switch from using "{}".format() to f-string
    job_id = "{}_{}".format(self._study_id, trial.trial_id)

    # Create job spec from worker count and config
    job_spec = self._get_job_spec_from_config(job_id)

    tf.get_logger().info("Calling cloud_fit with %s", {
        "model": model,
        "remote_dir": remote_dir,
        "region": self._region,
        "project_id": self._project_id,
        "image_uri": self._container_uri,
        "job_id": job_id,
        "*fit_args": fit_args,
        "job_spec": job_spec,
        "**copied_fit_kwargs": copied_fit_kwargs})

    # Positional `fit` arguments were folded into `copied_fit_kwargs` above.
    client.cloud_fit(
        model=model,
        remote_dir=remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._container_uri,
        job_id=job_id,
        job_spec=job_spec,
        **copied_fit_kwargs)

    # Create an instance of tensorboard DirectoryWatcher to retrieve the
    # logs for this trial run
    log_path = self._get_tensorboard_log_dir(trial.trial_id)

    # TODO(b/170687807) Switch from using "{}".format() to f-string
    tf.get_logger().info(
        "Retrieving training logs for trial {} from {}".format(
            trial.trial_id, log_path))
    log_reader = tf_utils.get_tensorboard_log_watcher_from_path(log_path)

    training_metrics = _TrainingMetrics([], {})
    epoch = 0

    while google_api_client.is_api_training_job_running(
            job_id, self._project_id):

        time.sleep(_POLLING_INTERVAL_IN_SECONDS)

        # Retrieve available metrics if any
        training_metrics = self._get_remote_training_metrics(
            log_reader, training_metrics.partial_epoch_metrics)

        trial_stopped = False
        for epoch_metrics in training_metrics.completed_epoch_metrics:
            # TODO(b/169197272) Validate metrics contain oracle objective
            trial.status = self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=epoch_metrics,
                step=epoch)
            epoch += 1
            if trial.status == "STOPPED":
                trial_stopped = True
                break

        # The `break` above only leaves the metrics loop. Stop the remote job
        # and also leave the polling loop. Otherwise the stopped trial keeps
        # being polled, and fed further metrics, until the job ends on its own.
        if trial_stopped:
            google_api_client.stop_aip_training_job(job_id, self._project_id)
            break

    # Ensure the training job has completed successfully.
    # NOTE(review): confirm what wait_for_api_training_job_completion returns
    # for a job stopped above; if False, an early-stopped trial raises here.
    if not google_api_client.wait_for_api_training_job_completion(
            job_id, self._project_id):
        raise RuntimeError(
            "AIP Training job failed, see logs for details at https://console.cloud.google.com/ai-platform/jobs/{}/charts/cpu?project={}"  # pylint: disable=line-too-long
            .format(job_id, self._project_id))

    # Retrieve and report any remaining metrics
    training_metrics = self._get_remote_training_metrics(
        log_reader, training_metrics.partial_epoch_metrics)

    for epoch_metrics in training_metrics.completed_epoch_metrics:
        # TODO(b/169197272) Validate metrics contain oracle objective
        # TODO(b/170907612) Support submit partial results to Oracle
        self.oracle.update_trial(
            trial_id=trial.trial_id,
            metrics=epoch_metrics,
            step=epoch)
        epoch += 1