Пример #1
0
    def test_create_or_load_study_with_409_raises_RuntimeError(
            self, mock_discovery):
        """Verify the RuntimeError raised after 3 failed GetStudy attempts.

        CreateStudy fails with 409 (study already exists), so the test expects
        the client to fall back to GetStudy. Every GetStudy call fails with a
        400 error, and the test expects a RuntimeError naming the number of
        tries and the last HttpError once the attempts are used up.
        """
        mock_create_request = mock.MagicMock()
        mock_create_request.execute.side_effect = errors.HttpError(
            httplib2.Response(info={"status": 409}), b"")
        mock_create_study = mock.MagicMock()
        mock_create_study.return_value = mock_create_request
        mock_discovery.build_from_document.return_value.projects().locations(
        ).studies().create = mock_create_study

        # Here get() itself raises (rather than its execute()); one error is
        # queued per expected attempt.
        mock_get_study = mock.MagicMock()
        mock_get_study.side_effect = [
            errors.HttpError(httplib2.Response(info={"status": 400}), b"")
        ] * 3
        mock_discovery.build_from_document.return_value.projects().locations(
        ).studies().get = mock_get_study

        with self.assertRaisesRegex(
                RuntimeError,
                'GetStudy wasn\'t successful after 3 tries: <HttpError 400 "Ok">',
        ):
            optimizer_client.create_or_load_study(
                project_id=self._project_id,
                region=self._region,
                study_id=self._study_id,
                study_config=self._study_config,
            )

        # Exactly the three queued GetStudy failures were consumed.
        self.assertEqual(mock_get_study.call_count, 3)
Пример #2
0
    def test_create_or_load_study_no_study_config_with_404_raises_ValueError(
        self, mock_discovery):
        """Verify that a 404 on GetStudy without a study_config is a ValueError.

        No study_config is passed, GetStudy fails with 404, and the test
        expects a ValueError whose message names the missing study. It also
        checks that GetStudy was called with the full study name.
        """
        mock_get_request = mock.MagicMock()
        mock_get_request.execute.side_effect = errors.HttpError(
            httplib2.Response(info={"status": 404}), b""
        )
        mock_get_study = mock.MagicMock()
        mock_get_study.return_value = mock_get_request
        mock_discovery.build_from_document.return_value.projects().locations(
            ).studies().get = mock_get_study

        # Match the message literally: as a regex, the "." characters (and any
        # metacharacters in the study id) would match more than intended.
        with self.assertRaises(ValueError) as context:
            optimizer_client.create_or_load_study(
                project_id=self._project_id,
                region=self._region,
                study_id=self._study_id,
            )
        self.assertIn(
            "GetStudy failed. Study not found: {}.".format(self._study_id),
            str(context.exception),
        )

        mock_get_study.assert_called_with(
            name="projects/{}/locations/{}/studies/{}".format(
                self._project_id, self._region, self._study_id
            )
        )
Пример #3
0
    def test_create_or_load_study_newstudy(self, mock_discovery):
        """Verify that a fresh study is created via the regional endpoint.

        Checks the returned client type, the discovery document handed to
        build_from_document, and the arguments of the CreateStudy call.
        """
        create_mock = mock.MagicMock()
        studies_resource = mock_discovery.build_from_document.return_value.projects(
        ).locations().studies()
        studies_resource.create = create_mock

        oracle_client = optimizer_client.create_or_load_study(
            project_id=self._project_id,
            region=self._region,
            study_id=self._study_id,
            study_config=self._study_config,
        )
        self.assertIsInstance(oracle_client, optimizer_client._OptimizerClient)

        # The test expects the discovery document as the `service` keyword.
        build_kwargs = mock_discovery.build_from_document.call_args[1]
        self.assertIn("service", build_kwargs)
        service_document = build_kwargs["service"]
        self.assertIsInstance(service_document, dict)
        # Regional endpoint must be specified for Optimizer client.
        expected_root_url = "https://us-central1-ml.googleapis.com/"
        self.assertEqual(service_document["rootUrl"], expected_root_url)

        expected_parent = "projects/{}/locations/{}".format(
            self._project_id, self._region)
        create_mock.assert_called_with(
            body={"study_config": self._study_config},
            parent=expected_parent,
            studyId=self._study_id,
        )
Пример #4
0
    def test_create_or_load_study_with_409_success(self, mock_discovery):
        """Verify that create_or_load_study recovers from transient GetStudy errors.

        CreateStudy fails with 409 (study already exists), so the test expects
        the client to load the existing study instead. The first two GetStudy
        executions fail with 400 and the third succeeds, which must be enough
        to obtain a client.
        """
        mock_create_request = mock.MagicMock()
        mock_create_request.execute.side_effect = errors.HttpError(
            httplib2.Response(info={"status": 409}), b"")
        mock_create_study = mock.MagicMock()
        mock_create_study.return_value = mock_create_request
        mock_discovery.build_from_document.return_value.projects().locations(
        ).studies().create = mock_create_study

        mock_get_request = mock.MagicMock()
        mock_get_request.execute.side_effect = [
            errors.HttpError(httplib2.Response(info={"status": 400}), b""),
            errors.HttpError(httplib2.Response(info={"status": 400}), b""),
            mock.DEFAULT,
        ]
        mock_get_study = mock.MagicMock()
        # Use return_value, not side_effect: a MagicMock assigned as side_effect
        # is *called*, so get() would return mock_get_request.return_value,
        # whose execute() never raises -- leaving the retry path untested.
        mock_get_study.return_value = mock_get_request
        mock_discovery.build_from_document.return_value.projects().locations(
        ).studies().get = mock_get_study

        client = optimizer_client.create_or_load_study(
            project_id=self._project_id,
            region=self._region,
            study_id=self._study_id,
            study_config=self._study_config,
        )
        self.assertIsInstance(client, optimizer_client._OptimizerClient)
        # Two failed GetStudy attempts followed by one successful attempt.
        self.assertEqual(mock_get_request.execute.call_count, 3)
Пример #5
0
    def tearDown(self):
        """Delete the test's study (if one was set) and clear the Keras session.

        The Keras session is cleared even when deleting the study raises; the
        deletion error still propagates.
        """
        super(_CloudTunerIntegrationTestBase, self).tearDown()

        try:
            # Delete the study used in the test, if present.
            if self._study_id:
                service = optimizer_client.create_or_load_study(
                    _PROJECT_ID, _REGION, self._study_id, None)
                service.delete_study()
        finally:
            tf.keras.backend.clear_session()
    def tearDown(self):
        """Delete the test's study, clear the Keras session, remove remote files.

        Session clearing and remote-directory cleanup run even when deleting
        the study raises; the deletion error still propagates.
        """
        super(_DistributingCloudTunerIntegrationTestBase, self).tearDown()

        try:
            # Delete the study used in the test, if present.
            if self._study_id:
                service = optimizer_client.create_or_load_study(
                    _PROJECT_ID, _REGION, self._study_id, None)
                service.delete_study()
        finally:
            tf.keras.backend.clear_session()

            # Delete log files, saved_models and other training assets.
            self._delete_dir(_REMOTE_DIR)
Пример #7
0
    def __init__(
        self,
        project_id: Text,
        region: Text,
        objective: Optional[Union[Text, oracle_module.Objective]] = None,
        hyperparameters: Optional[hp_module.HyperParameters] = None,
        study_config: Optional[Dict[Text, Any]] = None,
        max_trials: Optional[int] = None,
        study_id: Optional[Text] = None,
    ):
        """KerasTuner Oracle interface implemented with Optimizer backend.

        The search is described either by `study_config` alone, or by both
        `objective` and `hyperparameters`; mixing the two is an error.

        Args:
            project_id: A GCP project id.
            region: A GCP region. e.g. 'us-central1'.
            objective: The optimization objective. If a string, the direction
                of the optimization (min or max) will be inferred. Must not be
                combined with `study_config`.
            hyperparameters: Required when `study_config` is not set, and must
                include definitions for all hyperparameters used during the
                search. Can be used to override (or register in advance)
                hyperparameters in the search space. Must not be combined with
                `study_config`.
            study_config: Study configuration for CAIP Optimizer service. When
                given, the objective and hyperparameters are derived from it.
            max_trials: Total number of trials (model configurations) to test at
                most. If None, it continues the search until it reaches the
                Optimizer trial limit for each study. Users may stop the search
                externally (e.g. by killing the job). Note that the Oracle may
                interrupt the search before `max_trials` models have been
                tested.
            study_id: An identifier of the study. If not supplied, one of the
                form `CloudTuner_study_YYYYmmdd_HHMMSS` is generated from the
                current local time. It has only one-second resolution, so
                oracles created within the same second get the same ID.
                The full study name will be
                `projects/{project_id}/locations/{region}/studies/{study_id}`,
                and the full trial name will be
                `{study name}/trials/{trial_id}`.

        Raises:
            ValueError: If `study_config` is combined with `objective` or
                `hyperparameters`; if `study_config` is absent and either
                `objective` or `hyperparameters` is missing; or if
                `project_id` or `region` is empty.
        """
        if study_config:
            if objective or hyperparameters:
                raise ValueError(
                    "Please configure either study_config or "
                    '"objective, and hyperparameters".'
                )
            objective = utils.convert_study_config_to_objective(study_config)
            hyperparameters = utils.convert_study_config_to_hps(study_config)
            self.study_config = study_config
        else:
            if not (objective and hyperparameters):
                raise ValueError(
                    "If study_config is not set, "
                    "objective and hyperparameters must be set."
                )
            self.study_config = utils.make_study_config(objective,
                                                        hyperparameters)

        # New hyperparameter entries are disallowed; presumably because the
        # Optimizer study's search space is fixed at creation -- confirm.
        super(CloudOracle, self).__init__(
            objective=objective,
            hyperparameters=hyperparameters,
            max_trials=max_trials,
            allow_new_entries=False,
            tune_new_entries=False,
        )

        if not project_id:
            raise ValueError('"project_id" is not found.')
        self._project_id = project_id

        if not region:
            raise ValueError('"region" is not found.')
        self._region = region

        # NOTE(review): these re-assign attributes after the base __init__;
        # `format_objective` presumably normalizes the objective -- confirm.
        self.objective = utils.format_objective(objective)
        self.hyperparameters = hyperparameters
        self.max_trials = max_trials

        self.study_id = study_id or "CloudTuner_study_{}".format(
            datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        )

        self.service = optimizer_client.create_or_load_study(
            self._project_id, self._region, self.study_id, self.study_config
        )

        self.trials = {}
        self._start_time = None