Пример #1
0
    def test_remote_run_trial_with_oracle_canceling_job(
            self, mock_tf_io, mock_is_running, mock_super_tuner,
            mock_job_status, mock_cloud_fit, mock_stop_job):
        """Check run_trial when the oracle's update_trial returns "STOPPED".

        The oracle stub answers every update_trial call with "STOPPED"; the
        test then expects two update_trial calls, two remote-metrics reads
        and exactly one stop-job call for the fixture's job and project ids.

        The mock_* parameters are presumably injected by patch decorators
        defined outside this view; several are unused in the body.
        """
        remote_tuner = self._remote_tuner(
            None, None, self._study_config, max_trials=10)

        # First poll reports a running job, the second reports it finished.
        mock_is_running.side_effect = [True, False]
        mock_job_status.return_value = True

        # Canned metrics returned on every remote-metrics read.
        canned_metrics = tuner._TrainingMetrics([{"loss": 0.001}], {})
        remote_tuner._get_remote_training_metrics = mock.Mock(
            return_value=canned_metrics)

        # Spec'd oracle whose update_trial always answers "STOPPED".
        oracle_stub = mock.create_autospec(
            oracle_module.Oracle, instance=True, spec_set=True)
        oracle_stub.update_trial = mock.Mock(return_value="STOPPED")
        remote_tuner.oracle = oracle_stub

        remote_tuner.hypermodel = mock.create_autospec(
            hypermodel_module.HyperModel, instance=True, spec_set=True)

        remote_tuner.run_trial(
            self._test_trial,
            "fit_arg",
            callbacks=["test_call_back"],
            fit_kwarg=1)

        self.assertEqual(2, remote_tuner.oracle.update_trial.call_count)
        self.assertEqual(
            2, remote_tuner._get_remote_training_metrics.call_count)
        mock_stop_job.assert_called_once_with(self._job_id, self._project_id)
Пример #2
0
    def test_remote_run_trial_with_successful_job(
            self, mock_tf_io, mock_log_watcher, mock_is_running,
            mock_super_tuner, mock_job_status, mock_cloud_fit):
        """Check run_trial for a remote job that reports success.

        Expects two oracle.update_trial calls, a cloud-fit call carrying the
        forwarded fit arguments plus the fixture's job settings, the log
        watcher and the tf-io mock both called with the trial's "train" log
        path, and two remote-metrics reads.

        The mock_* parameters are presumably injected by patch decorators
        defined outside this view; mock_super_tuner is unused in the body.
        """
        remote_tuner = self._remote_tuner(
            None, None, self._study_config, max_trials=10)

        # First poll reports a running job, the second reports it finished.
        mock_is_running.side_effect = [True, False]

        trial_id = self._test_trial.trial_id
        remote_dir = os.path.join(remote_tuner.directory, str(trial_id))
        mock_job_status.return_value = True

        # Canned metrics returned on every remote-metrics read.
        canned_metrics = tuner._TrainingMetrics([{"loss": 0.001}], {})
        remote_tuner._get_remote_training_metrics = mock.Mock(
            return_value=canned_metrics)

        remote_tuner.oracle = mock.Mock(update_trial=mock.Mock())
        remote_tuner.hypermodel = mock.Mock()

        remote_tuner.run_trial(
            self._test_trial,
            "fit_arg",
            callbacks=["test_call_back"],
            fit_kwarg=1)

        self.assertEqual(2, remote_tuner.oracle.update_trial.call_count)

        # Two extra callbacks are expected after the caller-supplied one;
        # their identity is not checked here.
        expected_fit_kwargs = {
            "fit_kwarg": 1,
            "model": mock.ANY,
            "callbacks": ["test_call_back", mock.ANY, mock.ANY],
            "remote_dir": remote_dir,
            "job_spec": mock.ANY,
            "region": self._region,
            "project_id": self._project_id,
            "image_uri": self._container_uri,
            "job_id": self._job_id,
        }
        mock_cloud_fit.assert_called_with("fit_arg", **expected_fit_kwargs)

        tensorboard_dir = remote_tuner._get_tensorboard_log_dir(trial_id)
        train_log_path = os.path.join(tensorboard_dir, "train")
        mock_log_watcher.assert_called_with(train_log_path)
        self.assertEqual(
            2, remote_tuner._get_remote_training_metrics.call_count)
        mock_tf_io.assert_called_with(train_log_path)