def test_remote_run_trial_with_oracle_canceling_job(
        self, mock_tf_io, mock_is_running, mock_super_tuner, mock_job_status,
        mock_cloud_fit, mock_stop_job):
    """Checks run_trial when the oracle answers update_trial with "STOPPED".

    The patched job-status poll reports the job as running once and then as
    finished. Every oracle.update_trial call returns "STOPPED". The test
    expects:
      * two metric fetches and two trial updates, and
      * exactly one stop_job call for the configured job/project ids.
    """
    remote_tuner = self._remote_tuner(
        None, None, self._study_config, max_trials=10)

    # First poll: still running; second poll: no longer running.
    mock_is_running.side_effect = [True, False]
    mock_job_status.return_value = True

    remote_tuner._get_remote_training_metrics = mock.Mock(
        return_value=tuner._TrainingMetrics([{"loss": 0.001}], {}))

    remote_tuner.oracle = mock.create_autospec(
        oracle_module.Oracle, instance=True, spec_set=True)
    # The oracle cancels the trial on every status update.
    remote_tuner.oracle.update_trial = mock.Mock(return_value="STOPPED")
    remote_tuner.hypermodel = mock.create_autospec(
        hypermodel_module.HyperModel, instance=True, spec_set=True)

    remote_tuner.run_trial(
        self._test_trial, "fit_arg", callbacks=["test_call_back"], fit_kwarg=1)

    for called_stub in (remote_tuner.oracle.update_trial,
                        remote_tuner._get_remote_training_metrics):
        self.assertEqual(2, called_stub.call_count)
    mock_stop_job.assert_called_once_with(self._job_id, self._project_id)
def test_remote_run_trial_with_successful_job(
        self, mock_tf_io, mock_log_watcher, mock_is_running, mock_super_tuner,
        mock_job_status, mock_cloud_fit):
    """Checks run_trial against a remote job that finishes normally.

    The patched job-status poll reports the job as running once and then as
    finished. The test verifies:
      * the exact arguments forwarded to the patched cloud_fit call;
      * that the patched log-watcher and tf.io mocks both last received the
        trial's TensorBoard "train" log directory;
      * that metrics were fetched, and the oracle updated, twice.
    """
    remote_tuner = self._remote_tuner(
        None, None, self._study_config, max_trials=10)

    # First poll: still running; second poll: no longer running.
    mock_is_running.side_effect = [True, False]
    expected_remote_dir = os.path.join(
        remote_tuner.directory, str(self._test_trial.trial_id))
    mock_job_status.return_value = True

    remote_tuner._get_remote_training_metrics = mock.Mock(
        return_value=tuner._TrainingMetrics([{"loss": 0.001}], {}))
    remote_tuner.oracle = mock.Mock()
    remote_tuner.oracle.update_trial = mock.Mock()
    remote_tuner.hypermodel = mock.Mock()

    remote_tuner.run_trial(
        self._test_trial, "fit_arg", callbacks=["test_call_back"], fit_kwarg=1)

    self.assertEqual(2, remote_tuner.oracle.update_trial.call_count)
    mock_cloud_fit.assert_called_with(
        "fit_arg",
        fit_kwarg=1,
        model=mock.ANY,
        # The user's callback comes first, followed by two extra callbacks
        # that run_trial is expected to append.
        callbacks=["test_call_back", mock.ANY, mock.ANY],
        remote_dir=expected_remote_dir,
        job_spec=mock.ANY,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._container_uri,
        job_id=self._job_id,
    )

    trial_log_dir = remote_tuner._get_tensorboard_log_dir(
        self._test_trial.trial_id)
    expected_train_log_dir = os.path.join(trial_log_dir, "train")
    mock_log_watcher.assert_called_with(expected_train_log_dir)
    self.assertEqual(2, remote_tuner._get_remote_training_metrics.call_count)
    mock_tf_io.assert_called_with(expected_train_log_dir)