Example #1
0
    def testSetStatusQuo(self, mock_fit, mock_observations_from_data):
        # NOTE: If an empty Data object is not passed, observations are not
        # extracted, even with the mock.
        modelbridge = ModelBridge(
            search_space=get_search_space_for_value(),
            model=0,
            experiment=get_experiment_for_value(),
            data=Data(),
            status_quo_name="1_1",
        )
        self.assertEqual(modelbridge.status_quo, get_observation1())

        # Alternatively, we can specify by features
        modelbridge = ModelBridge(
            get_search_space_for_value(),
            0,
            [],
            get_experiment_for_value(),
            0,
            status_quo_features=get_observation1().features,
        )
        self.assertEqual(modelbridge.status_quo, get_observation1())

        # Alternatively, we can specify the status quo on the experiment:
        # put a dummy arm with SQ name 1_1 on the dummy experiment.
        exp = get_experiment_for_value()
        sq = Arm(name="1_1", parameters={"x": 3.0})
        exp._status_quo = sq
        # Check that we set SQ to arm 1_1
        modelbridge = ModelBridge(get_search_space_for_value(), 0, [], exp, 0)
        self.assertEqual(modelbridge.status_quo, get_observation1())

        # Errors if features and name both specified
        with self.assertRaises(ValueError):
            modelbridge = ModelBridge(
                get_search_space_for_value(),
                0,
                [],
                exp,
                0,
                status_quo_features=get_observation1().features,
                status_quo_name="1_1",
            )

        # Left as None if features or name don't exist
        modelbridge = ModelBridge(
            get_search_space_for_value(), 0, [], exp, 0, status_quo_name="1_0"
        )
        self.assertIsNone(modelbridge.status_quo)
        modelbridge = ModelBridge(
            get_search_space_for_value(),
            0,
            [],
            get_experiment_for_value(),
            0,
            status_quo_features=ObservationFeatures(parameters={"x": 3.0, "y": 10.0}),
        )
        self.assertIsNone(modelbridge.status_quo)
Example #2
0
    def testUnwrapObservationData(self):
        observation_data = [get_observation1().data, get_observation2().data]
        f, cov = unwrap_observation_data(observation_data)
        self.assertEqual(f["a"], [2.0, 2.0])
        self.assertEqual(f["b"], [4.0, 1.0])
        self.assertEqual(cov["a"]["a"], [1.0, 2.0])
        self.assertEqual(cov["b"]["b"], [4.0, 5.0])
        self.assertEqual(cov["a"]["b"], [2.0, 3.0])
        self.assertEqual(cov["b"]["a"], [3.0, 4.0])
        # Check that it errors if the metric names do not match.
        od3 = ObservationData(metric_names=["a"],
                              means=np.array([2.0]),
                              covariance=np.array([[4.0]]))
        with self.assertRaises(ValueError):
            unwrap_observation_data(observation_data + [od3])
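For reference, the per-metric structure that these assertions exercise can be sketched directly from the expected values above (a sketch inferred from the assertions, not the helper's actual return object):

# Sketch of the (f, cov) layout checked above: means and covariances are
# grouped per metric, with one list entry per observation in the input.
f_sketch = {"a": [2.0, 2.0], "b": [4.0, 1.0]}
cov_sketch = {
    "a": {"a": [1.0, 2.0], "b": [2.0, 3.0]},
    "b": {"a": [3.0, 4.0], "b": [4.0, 5.0]},
}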
Example #3
0
    def testModelBridge(self, mock_fit, mock_gen_arms,
                        mock_observations_from_data):
        # Test that on init transforms are stored and applied in the correct order
        transforms = [transform_1, transform_2]
        exp = get_experiment_for_value()
        ss = get_search_space_for_value()
        modelbridge = ModelBridge(ss, 0, transforms, exp, 0)
        self.assertEqual(list(modelbridge.transforms.keys()),
                         ["transform_1", "transform_2"])
        fit_args = mock_fit.mock_calls[0][2]
        self.assertTrue(
            fit_args["search_space"] == get_search_space_for_value(8.0))
        self.assertTrue(fit_args["observation_features"] == [])
        self.assertTrue(fit_args["observation_data"] == [])
        self.assertTrue(mock_observations_from_data.called)

        # Test prediction on out of design features.
        modelbridge._predict = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._predict",
            autospec=True,
            side_effect=ValueError("Out of Design"),
        )
        # This point is in design, and thus failures in predict are legitimate.
        with mock.patch.object(ModelBridge,
                               "model_space",
                               return_value=get_search_space_for_range_values):
            with self.assertRaises(ValueError):
                modelbridge.predict([get_observation2().features])

        # This point is out of design, and not in training data.
        with self.assertRaises(ValueError):
            modelbridge.predict([get_observation_status_quo0().features])

        # Now it's in the training data.
        with mock.patch.object(
                ModelBridge,
                "get_training_data",
                return_value=[get_observation_status_quo0()],
        ):
            # Return raw training value.
            self.assertEqual(
                modelbridge.predict([get_observation_status_quo0().features]),
                unwrap_observation_data([get_observation_status_quo0().data]),
            )

        # Test that transforms are applied correctly on predict
        modelbridge._predict = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._predict",
            autospec=True,
            return_value=[get_observation2trans().data],
        )
        modelbridge.predict([get_observation2().features])
        # Observation features sent to _predict are un-transformed afterwards
        modelbridge._predict.assert_called_with([get_observation2().features])

        # Check that _single_predict is equivalent here.
        modelbridge._single_predict([get_observation2().features])
        # Observation features sent to _predict are un-transformed afterwards
        modelbridge._predict.assert_called_with([get_observation2().features])

        # Test transforms applied on gen
        modelbridge._gen = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._gen",
            autospec=True,
            return_value=([get_observation1trans().features], [2], None, {}),
        )
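        # The mocked return value mirrors _gen's output shape: candidate
        # observation features, their weights, an optional best point, and a
        # dict of generation metadata.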
        oc = OptimizationConfig(objective=Objective(metric=Metric(
            name="test_metric")))
        modelbridge._set_kwargs_to_save(model_key="TestModel",
                                        model_kwargs={},
                                        bridge_kwargs={})
        gr = modelbridge.gen(
            n=1,
            search_space=get_search_space_for_value(),
            optimization_config=oc,
            pending_observations={"a": [get_observation2().features]},
            fixed_features=ObservationFeatures({"x": 5}),
        )
        self.assertEqual(gr._model_key, "TestModel")
        modelbridge._gen.assert_called_with(
            n=1,
            search_space=SearchSpace(
                [FixedParameter("x", ParameterType.FLOAT, 8.0)]),
            optimization_config=oc,
            pending_observations={"a": [get_observation2trans().features]},
            fixed_features=ObservationFeatures({"x": 36}),
            model_gen_options=None,
        )
        mock_gen_arms.assert_called_with(
            arms_by_signature={},
            observation_features=[get_observation1().features])

        # Gen with no pending observations and no fixed features
        modelbridge.gen(n=1,
                        search_space=get_search_space_for_value(),
                        optimization_config=None)
        modelbridge._gen.assert_called_with(
            n=1,
            search_space=SearchSpace(
                [FixedParameter("x", ParameterType.FLOAT, 8.0)]),
            optimization_config=None,
            pending_observations={},
            fixed_features=ObservationFeatures({}),
            model_gen_options=None,
        )

        # Gen with multi-objective optimization config.
        oc2 = OptimizationConfig(objective=ScalarizedObjective(
            metrics=[Metric(name="test_metric"),
                     Metric(name="test_metric_2")]))
        modelbridge.gen(n=1,
                        search_space=get_search_space_for_value(),
                        optimization_config=oc2)
        modelbridge._gen.assert_called_with(
            n=1,
            search_space=SearchSpace(
                [FixedParameter("x", ParameterType.FLOAT, 8.0)]),
            optimization_config=oc2,
            pending_observations={},
            fixed_features=ObservationFeatures({}),
            model_gen_options=None,
        )

        # Test transforms applied on cross_validate
        modelbridge._cross_validate = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._cross_validate",
            autospec=True,
            return_value=[get_observation1trans().data],
        )
        cv_training_data = [get_observation2()]
        cv_test_points = [get_observation1().features]
        cv_predictions = modelbridge.cross_validate(
            cv_training_data=cv_training_data, cv_test_points=cv_test_points)
        modelbridge._cross_validate.assert_called_with(
            obs_feats=[get_observation2trans().features],
            obs_data=[get_observation2trans().data],
            cv_test_points=[get_observation1().features
                            ],  # untransformed after
        )
        self.assertTrue(cv_predictions == [get_observation1().data])

        # Test stored training data
        obs = modelbridge.get_training_data()
        self.assertTrue(obs == [get_observation1(), get_observation2()])
        self.assertEqual(modelbridge.metric_names, {"a", "b"})
        self.assertIsNone(modelbridge.status_quo)
        self.assertTrue(
            modelbridge.model_space == get_search_space_for_value())
        self.assertEqual(modelbridge.training_in_design, [False, False])

        with self.assertRaises(ValueError):
            modelbridge.training_in_design = [True, True, False]

        with self.assertRaises(ValueError):
            modelbridge.training_in_design = [True, True, False]

        # Test feature_importances
        with self.assertRaises(NotImplementedError):
            modelbridge.feature_importances("a")
Example #4
0
class BaseModelBridgeTest(TestCase):
    @mock.patch(
        "ax.modelbridge.base.observations_from_data",
        autospec=True,
        return_value=([get_observation1(),
                       get_observation2()]),
    )
    @mock.patch("ax.modelbridge.base.gen_arms",
                autospec=True,
                return_value=[Arm(parameters={})])
    @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True)
    def testModelBridge(self, mock_fit, mock_gen_arms,
                        mock_observations_from_data):
        # Test that on init transforms are stored and applied in the correct order
        transforms = [transform_1, transform_2]
        exp = get_experiment_for_value()
        ss = get_search_space_for_value()
        modelbridge = ModelBridge(ss, 0, transforms, exp, 0)
        self.assertEqual(list(modelbridge.transforms.keys()),
                         ["transform_1", "transform_2"])
        fit_args = mock_fit.mock_calls[0][2]
        self.assertTrue(
            fit_args["search_space"] == get_search_space_for_value(8.0))
        self.assertTrue(fit_args["observation_features"] == [])
        self.assertTrue(fit_args["observation_data"] == [])
        self.assertTrue(mock_observations_from_data.called)

        # Test prediction on out of design features.
        modelbridge._predict = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._predict",
            autospec=True,
            side_effect=ValueError("Out of Design"),
        )
        # This point is in design, and thus failures in predict are legitimate.
        with mock.patch.object(ModelBridge,
                               "model_space",
                               return_value=get_search_space_for_range_values):
            with self.assertRaises(ValueError):
                modelbridge.predict([get_observation2().features])

        # This point is out of design, and not in training data.
        with self.assertRaises(ValueError):
            modelbridge.predict([get_observation_status_quo0().features])

        # Now it's in the training data.
        with mock.patch.object(
                ModelBridge,
                "get_training_data",
                return_value=[get_observation_status_quo0()],
        ):
            # Return raw training value.
            self.assertEqual(
                modelbridge.predict([get_observation_status_quo0().features]),
                unwrap_observation_data([get_observation_status_quo0().data]),
            )

        # Test that transforms are applied correctly on predict
        modelbridge._predict = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._predict",
            autospec=True,
            return_value=[get_observation2trans().data],
        )
        modelbridge.predict([get_observation2().features])
        # Observation features sent to _predict are un-transformed afterwards
        modelbridge._predict.assert_called_with([get_observation2().features])

        # Check that _single_predict is equivalent here.
        modelbridge._single_predict([get_observation2().features])
        # Observation features sent to _predict are un-transformed afterwards
        modelbridge._predict.assert_called_with([get_observation2().features])

        # Test transforms applied on gen
        modelbridge._gen = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._gen",
            autospec=True,
            return_value=([get_observation1trans().features], [2], None, {}),
        )
        oc = OptimizationConfig(objective=Objective(metric=Metric(
            name="test_metric")))
        modelbridge._set_kwargs_to_save(model_key="TestModel",
                                        model_kwargs={},
                                        bridge_kwargs={})
        gr = modelbridge.gen(
            n=1,
            search_space=get_search_space_for_value(),
            optimization_config=oc,
            pending_observations={"a": [get_observation2().features]},
            fixed_features=ObservationFeatures({"x": 5}),
        )
        self.assertEqual(gr._model_key, "TestModel")
        modelbridge._gen.assert_called_with(
            n=1,
            search_space=SearchSpace(
                [FixedParameter("x", ParameterType.FLOAT, 8.0)]),
            optimization_config=oc,
            pending_observations={"a": [get_observation2trans().features]},
            fixed_features=ObservationFeatures({"x": 36}),
            model_gen_options=None,
        )
        mock_gen_arms.assert_called_with(
            arms_by_signature={},
            observation_features=[get_observation1().features])

        # Gen with no pending observations and no fixed features
        modelbridge.gen(n=1,
                        search_space=get_search_space_for_value(),
                        optimization_config=None)
        modelbridge._gen.assert_called_with(
            n=1,
            search_space=SearchSpace(
                [FixedParameter("x", ParameterType.FLOAT, 8.0)]),
            optimization_config=None,
            pending_observations={},
            fixed_features=ObservationFeatures({}),
            model_gen_options=None,
        )

        # Gen with multi-objective optimization config.
        oc2 = OptimizationConfig(objective=ScalarizedObjective(
            metrics=[Metric(name="test_metric"),
                     Metric(name="test_metric_2")]))
        modelbridge.gen(n=1,
                        search_space=get_search_space_for_value(),
                        optimization_config=oc2)
        modelbridge._gen.assert_called_with(
            n=1,
            search_space=SearchSpace(
                [FixedParameter("x", ParameterType.FLOAT, 8.0)]),
            optimization_config=oc2,
            pending_observations={},
            fixed_features=ObservationFeatures({}),
            model_gen_options=None,
        )

        # Test transforms applied on cross_validate
        modelbridge._cross_validate = mock.MagicMock(
            "ax.modelbridge.base.ModelBridge._cross_validate",
            autospec=True,
            return_value=[get_observation1trans().data],
        )
        cv_training_data = [get_observation2()]
        cv_test_points = [get_observation1().features]
        cv_predictions = modelbridge.cross_validate(
            cv_training_data=cv_training_data, cv_test_points=cv_test_points)
        modelbridge._cross_validate.assert_called_with(
            obs_feats=[get_observation2trans().features],
            obs_data=[get_observation2trans().data],
            cv_test_points=[get_observation1().features
                            ],  # untransformed after
        )
        self.assertTrue(cv_predictions == [get_observation1().data])

        # Test stored training data
        obs = modelbridge.get_training_data()
        self.assertTrue(obs == [get_observation1(), get_observation2()])
        self.assertEqual(modelbridge.metric_names, {"a", "b"})
        self.assertIsNone(modelbridge.status_quo)
        self.assertTrue(
            modelbridge.model_space == get_search_space_for_value())
        self.assertEqual(modelbridge.training_in_design, [False, False])

        with self.assertRaises(ValueError):
            modelbridge.training_in_design = [True, True, False]

        with self.assertRaises(ValueError):
            modelbridge.training_in_design = [True, True, False]

        # Test feature_importances
        with self.assertRaises(NotImplementedError):
            modelbridge.feature_importances("a")

    @mock.patch(
        "ax.modelbridge.base.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True)
    def testSetStatusQuo(self, mock_fit, mock_observations_from_data):
        # NOTE: If an empty Data object is not passed, observations are not
        # extracted, even with the mock.
        modelbridge = ModelBridge(
            search_space=get_search_space_for_value(),
            model=0,
            experiment=get_experiment_for_value(),
            data=Data(),
            status_quo_name="1_1",
        )
        self.assertEqual(modelbridge.status_quo, get_observation1())

        # Alternatively, we can specify by features
        modelbridge = ModelBridge(
            get_search_space_for_value(),
            0,
            [],
            get_experiment_for_value(),
            0,
            status_quo_features=get_observation1().features,
        )
        self.assertEqual(modelbridge.status_quo, get_observation1())

        # Alternatively, we can specify the status quo on the experiment:
        # put a dummy arm with SQ name 1_1 on the dummy experiment.
        exp = get_experiment_for_value()
        sq = Arm(name="1_1", parameters={"x": 3.0})
        exp._status_quo = sq
        # Check that we set SQ to arm 1_1
        modelbridge = ModelBridge(get_search_space_for_value(), 0, [], exp, 0)
        self.assertEqual(modelbridge.status_quo, get_observation1())

        # Errors if features and name both specified
        with self.assertRaises(ValueError):
            modelbridge = ModelBridge(
                get_search_space_for_value(),
                0,
                [],
                exp,
                0,
                status_quo_features=get_observation1().features,
                status_quo_name="1_1",
            )

        # Left as None if features or name don't exist
        modelbridge = ModelBridge(get_search_space_for_value(),
                                  0, [],
                                  exp,
                                  0,
                                  status_quo_name="1_0")
        self.assertIsNone(modelbridge.status_quo)
        modelbridge = ModelBridge(
            get_search_space_for_value(),
            0,
            [],
            get_experiment_for_value(),
            0,
            status_quo_features=ObservationFeatures(parameters={
                "x": 3.0,
                "y": 10.0
            }),
        )
        self.assertIsNone(modelbridge.status_quo)

    @mock.patch(
        "ax.modelbridge.base.observations_from_data",
        autospec=True,
        return_value=([
            get_observation_status_quo0(),
            get_observation_status_quo1(),
            get_observation1(),
            get_observation2(),
        ]),
    )
    @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True)
    def testSetStatusQuoMultipleObs(self, mock_fit,
                                    mock_observations_from_data):
        exp = get_experiment_with_repeated_arms(2)

        trial_index = 1
        status_quo_features = ObservationFeatures(
            parameters=exp.trials[trial_index].status_quo.parameters,
            trial_index=trial_index,
        )
        modelbridge = ModelBridge(
            get_search_space_for_value(),
            0,
            [],
            exp,
            0,
            status_quo_features=status_quo_features,
        )
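        # Including trial_index in status_quo_features disambiguates between
        # the repeated status quo arms, so the observation from trial 1 is the
        # one selected below.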
        # Check that for experiments with many trials the status quo is set
        # to the value of the status quo of the last trial.
        self.assertEqual(modelbridge.status_quo,
                         get_observation_status_quo1())

    @mock.patch(
        "ax.modelbridge.base.observations_from_data",
        autospec=True,
        return_value=([get_observation1(),
                       get_observation1()]),
    )
    @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True)
    def testSetTrainingDataDupFeatures(self, mock_fit,
                                       mock_observations_from_data):
        # Throws an error if repeated features in observations.
        with self.assertRaises(ValueError):
            ModelBridge(
                get_search_space_for_value(),
                0,
                [],
                get_experiment_for_value(),
                0,
                status_quo_name="1_1",
            )

    def testUnwrapObservationData(self):
        observation_data = [get_observation1().data, get_observation2().data]
        f, cov = unwrap_observation_data(observation_data)
        self.assertEqual(f["a"], [2.0, 2.0])
        self.assertEqual(f["b"], [4.0, 1.0])
        self.assertEqual(cov["a"]["a"], [1.0, 2.0])
        self.assertEqual(cov["b"]["b"], [4.0, 5.0])
        self.assertEqual(cov["a"]["b"], [2.0, 3.0])
        self.assertEqual(cov["b"]["a"], [3.0, 4.0])
        # Check that it errors if the metric names do not match.
        od3 = ObservationData(metric_names=["a"],
                              means=np.array([2.0]),
                              covariance=np.array([[4.0]]))
        with self.assertRaises(ValueError):
            unwrap_observation_data(observation_data + [od3])

    def testGenArms(self):
        p1 = {"x": 0, "y": 1}
        p2 = {"x": 4, "y": 8}
        observation_features = [
            ObservationFeatures(parameters=p1),
            ObservationFeatures(parameters=p2),
        ]
        arms = gen_arms(observation_features=observation_features)
        self.assertEqual(arms[0].parameters, p1)

        arm = Arm(name="1_1", parameters=p1)
        arms_by_signature = {arm.signature: arm}
        arms = gen_arms(
            observation_features=observation_features,
            arms_by_signature=arms_by_signature,
        )
        self.assertEqual(arms[0].name, "1_1")

    @mock.patch(
        "ax.modelbridge.base.ModelBridge._gen",
        autospec=True,
        return_value=([get_observation1trans().features], [2], None, {}),
    )
    @mock.patch("ax.modelbridge.base.ModelBridge.predict",
                autospec=True,
                return_value=None)
    def testGenWithDefaults(self, _, mock_gen):
        exp = get_experiment_for_value()
        exp.optimization_config = get_optimization_config_no_constraints()
        ss = get_search_space_for_range_value()
        modelbridge = ModelBridge(ss, None, [], exp)
        modelbridge.gen(1)
        mock_gen.assert_called_with(
            modelbridge,
            n=1,
            search_space=ss,
            fixed_features=ObservationFeatures(parameters={}),
            model_gen_options=None,
            optimization_config=OptimizationConfig(
                objective=Objective(metric=Metric("test_metric"),
                                    minimize=False),
                outcome_constraints=[],
            ),
            pending_observations={},
        )

    @mock.patch(
        "ax.modelbridge.base.ModelBridge._gen",
        autospec=True,
        side_effect=[
            ([get_observation1trans().features], [2], None, {}),
            ([get_observation2trans().features], [2], None, {}),
            ([get_observation2().features], [2], None, {}),
        ],
    )
    @mock.patch("ax.modelbridge.base.ModelBridge._update", autospec=True)
    def test_update(self, _mock_update, _mock_gen):
        exp = get_experiment_for_value()
        exp.optimization_config = get_optimization_config_no_constraints()
        ss = get_search_space_for_range_values()
        exp.search_space = ss
        modelbridge = ModelBridge(ss, None, [Log], exp)
        exp.new_trial(generator_run=modelbridge.gen(1))
        modelbridge._set_training_data(
            observations_from_data(
                data=Data(
                    pd.DataFrame([{
                        "arm_name": "0_0",
                        "metric_name": "m1",
                        "mean": 3.0,
                        "sem": 1.0,
                    }])),
                experiment=exp,
            ),
            ss,
        )
        exp.new_trial(generator_run=modelbridge.gen(1))
        modelbridge.update(
            data=Data(
                pd.DataFrame([{
                    "arm_name": "1_0",
                    "metric_name": "m1",
                    "mean": 5.0,
                    "sem": 0.0
                }])),
            experiment=exp,
        )
        exp.new_trial(generator_run=modelbridge.gen(1))
        # Trying to update with unrecognised metric should error.
        with self.assertRaisesRegex(ValueError, "Unrecognised metric"):
            modelbridge.update(
                data=Data(
                    pd.DataFrame([{
                        "arm_name": "1_0",
                        "metric_name": "m2",
                        "mean": 5.0,
                        "sem": 0.0,
                    }])),
                experiment=exp,
            )
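The get_observation1 / get_observation2 stubs referenced throughout these tests come from Ax's testing helpers and are not reproduced here. A minimal sketch of the kind of Observation such a stub returns (the parameter values are assumptions; the metric means and covariances mirror the values asserted in testUnwrapObservationData):

import numpy as np

from ax.core.observation import (
    Observation,
    ObservationData,
    ObservationFeatures,
)


def make_observation_stub() -> Observation:
    # Hypothetical stand-in for get_observation1; the real helper may differ.
    return Observation(
        features=ObservationFeatures(parameters={"x": 2.0, "y": 10.0}),
        data=ObservationData(
            metric_names=["a", "b"],
            means=np.array([2.0, 4.0]),
            covariance=np.array([[1.0, 2.0], [3.0, 4.0]]),
        ),
        arm_name="1_1",
    )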
Example #5
0
class TestAxClient(TestCase):
    """Tests service-like API functionality."""
    def setUp(self):
        # To avoid tests timing out due to GP fit / gen times.
        patch.dict(
            f"{Models.__module__}.MODEL_KEY_TO_MODEL_SETUP",
            {
                "GPEI": MODEL_KEY_TO_MODEL_SETUP["Sobol"]
            },
        ).start()
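        # Remapping the GPEI key onto the Sobol model setup means get_next_trial
        # never fits an actual GP in these tests, which keeps them fast.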

    def test_interruption(self) -> None:
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test",
            parameters=[  # pyre-fixme[6]: expected union that should include
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            objective_name="branin",
            minimize=True,
        )
        for i in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            self.assertFalse(  # There should be non-complete trials.
                all(t.status.is_terminal
                    for t in ax_client.experiment.trials.values()))
            x, y = parameterization.get("x"), parameterization.get("y")
            ax_client.complete_trial(
                trial_index,
                raw_data=checked_cast(
                    float,
                    branin(checked_cast(float, x), checked_cast(float, y))),
            )
            old_client = ax_client
            serialized = ax_client.to_json_snapshot()
            ax_client = AxClient.from_json_snapshot(serialized)
            self.assertEqual(len(ax_client.experiment.trials.keys()), i + 1)
            self.assertIsNot(ax_client, old_client)
            self.assertTrue(  # There should be no non-complete trials.
                all(t.status.is_terminal
                    for t in ax_client.experiment.trials.values()))

    @patch(
        "ax.modelbridge.base.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        "ax.modelbridge.random.RandomModelBridge.get_training_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        "ax.modelbridge.random.RandomModelBridge._predict",
        autospec=True,
        return_value=[get_observation1trans().data],
    )
    @patch(
        "ax.modelbridge.random.RandomModelBridge.feature_importances",
        autospec=True,
        return_value={
            "x": 0.9,
            "y": 1.1
        },
    )
    def test_default_generation_strategy_continuous(self, _a, _b, _c,
                                                    _d) -> None:
        """Test that Sobol+GPEI is used if no GenerationStrategy is provided."""
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[  # pyre-fixme[6]: expected union that should include
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            objective_name="a",
            minimize=True,
        )
        self.assertEqual(
            [s.model for s in not_none(ax_client.generation_strategy)._steps],
            [Models.SOBOL, Models.GPEI],
        )
        with self.assertRaisesRegex(ValueError, ".* no trials"):
            ax_client.get_optimization_trace(objective_optimum=branin.fmin)
        for i in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            x, y = parameterization.get("x"), parameterization.get("y")
            ax_client.complete_trial(
                trial_index,
                raw_data={
                    "a": (
                        checked_cast(
                            float,
                            branin(checked_cast(float, x),
                                   checked_cast(float, y)),
                        ),
                        0.0,
                    )
                },
                sample_size=i,
            )
        self.assertEqual(ax_client.generation_strategy.model._model_key,
                         "GPEI")
        ax_client.get_optimization_trace(objective_optimum=branin.fmin)
        ax_client.get_contour_plot()
        ax_client.get_feature_importances()
        trials_df = ax_client.get_trials_data_frame()
        self.assertIn("x", trials_df)
        self.assertIn("y", trials_df)
        self.assertIn("a", trials_df)
        self.assertEqual(len(trials_df), 6)

    def test_default_generation_strategy_discrete(self) -> None:
        """Test that Sobol is used if no GenerationStrategy is provided and
        the search space is discrete.
        """
        # Test that Sobol is chosen when all parameters are choice.
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[  # pyre-fixme[6]: expected union that should include
                {
                    "name": "x",
                    "type": "choice",
                    "values": [1, 2, 3]
                },
                {
                    "name": "y",
                    "type": "choice",
                    "values": [1, 2, 3]
                },
            ])
        self.assertEqual(
            [s.model for s in not_none(ax_client.generation_strategy)._steps],
            [Models.SOBOL],
        )
        self.assertEqual(ax_client.get_max_parallelism(), [(-1, -1)])
        self.assertTrue(ax_client.get_trials_data_frame().empty)

    def test_create_experiment(self) -> None:
        """Test basic experiment creation."""
        ax_client = AxClient(
            GenerationStrategy(
                steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]))
        with self.assertRaisesRegex(ValueError,
                                    "Experiment not set on Ax client"):
            ax_client.experiment
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [0.001, 0.1],
                    "value_type": "float",
                    "log_scale": True,
                },
                {
                    "name": "y",
                    "type": "choice",
                    "values": [1, 2, 3],
                    "value_type": "int",
                    "is_ordered": True,
                },
                {
                    "name": "x3",
                    "type": "fixed",
                    "value": 2,
                    "value_type": "int"
                },
                {
                    "name": "x4",
                    "type": "range",
                    "bounds": [1.0, 3.0],
                    "value_type": "int",
                },
                {
                    "name": "x5",
                    "type": "choice",
                    "values": ["one", "two", "three"],
                    "value_type": "str",
                },
                {
                    "name": "x6",
                    "type": "range",
                    "bounds": [1.0, 3.0],
                    "value_type": "int",
                },
            ],
            objective_name="test_objective",
            minimize=True,
            outcome_constraints=["some_metric >= 3", "some_metric <= 4.0"],
            parameter_constraints=["x4 <= x6"],
        )
        assert ax_client._experiment is not None
        self.assertEqual(ax_client._experiment, ax_client.experiment)
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x"],
            RangeParameter(
                name="x",
                parameter_type=ParameterType.FLOAT,
                lower=0.001,
                upper=0.1,
                log_scale=True,
            ),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["y"],
            ChoiceParameter(
                name="y",
                parameter_type=ParameterType.INT,
                values=[1, 2, 3],
                is_ordered=True,
            ),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x3"],
            FixedParameter(name="x3",
                           parameter_type=ParameterType.INT,
                           value=2),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x4"],
            RangeParameter(name="x4",
                           parameter_type=ParameterType.INT,
                           lower=1.0,
                           upper=3.0),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x5"],
            ChoiceParameter(
                name="x5",
                parameter_type=ParameterType.STRING,
                values=["one", "two", "three"],
            ),
        )
        self.assertEqual(
            ax_client._experiment.optimization_config.outcome_constraints[0],
            OutcomeConstraint(
                metric=Metric(name="some_metric"),
                op=ComparisonOp.GEQ,
                bound=3.0,
                relative=False,
            ),
        )
        self.assertEqual(
            ax_client._experiment.optimization_config.outcome_constraints[1],
            OutcomeConstraint(
                metric=Metric(name="some_metric"),
                op=ComparisonOp.LEQ,
                bound=4.0,
                relative=False,
            ),
        )
        self.assertTrue(
            ax_client._experiment.optimization_config.objective.minimize)

    def test_constraint_same_as_objective(self):
        """Check that we do not allow constraints on the objective metric."""
        ax_client = AxClient(
            GenerationStrategy(
                steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]))
        with self.assertRaises(ValueError):
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[{
                    "name": "x3",
                    "type": "fixed",
                    "value": 2,
                    "value_type": "int"
                }],
                objective_name="test_objective",
                outcome_constraints=["test_objective >= 3"],
            )

    def test_raw_data_format(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        for _ in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            x, y = parameterization.get("x"), parameterization.get("y")
            ax_client.complete_trial(trial_index, raw_data=(branin(x, y), 0.0))
        with self.assertRaisesRegex(ValueError,
                                    "Raw data has an invalid type"):
            ax_client.update_trial_data(trial_index, raw_data="invalid_data")

    def test_raw_data_format_with_fidelities(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 1.0]
                },
            ],
            minimize=True,
        )
        for _ in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            x, y = parameterization.get("x"), parameterization.get("y")
            ax_client.complete_trial(
                trial_index,
                raw_data=[
                    ({
                        "y": y / 2.0
                    }, {
                        "objective": (branin(x, y / 2.0), 0.0)
                    }),
                    ({
                        "y": y
                    }, {
                        "objective": (branin(x, y), 0.0)
                    }),
                ],
            )

    def test_keep_generating_without_data(self):
        # Check that normally the number of arms to generate is enforced.
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        for _ in range(5):
            parameterization, trial_index = ax_client.get_next_trial()
        with self.assertRaisesRegex(DataRequiredError,
                                    "All trials for current model"):
            ax_client.get_next_trial()
        # Check that with enforce_sequential_optimization off, we can keep
        # generating.
        ax_client = AxClient(enforce_sequential_optimization=False)
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        self.assertFalse(
            ax_client.generation_strategy._steps[0].enforce_num_trials)
        self.assertFalse(
            ax_client.generation_strategy._steps[1].max_parallelism)
        for _ in range(10):
            parameterization, trial_index = ax_client.get_next_trial()

    def test_trial_completion(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        params, idx = ax_client.get_next_trial()
        # Can't update before completing.
        with self.assertRaisesRegex(ValueError, ".* not yet"):
            ax_client.update_trial_data(trial_index=idx,
                                        raw_data={"objective": (0, 0.0)})
        ax_client.complete_trial(trial_index=idx,
                                 raw_data={"objective": (0, 0.0)})
        # Cannot complete a trial twice, should use `update_trial_data`.
        with self.assertRaisesRegex(ValueError, ".* already been completed"):
            ax_client.complete_trial(trial_index=idx,
                                     raw_data={"objective": (0, 0.0)})
        # Cannot update trial data with observation for a metric it already has.
        with self.assertRaisesRegex(ValueError, ".* contained an observation"):
            ax_client.update_trial_data(trial_index=idx,
                                        raw_data={"objective": (0, 0.0)})
        # Same as above, except objective name should be getting inferred.
        with self.assertRaisesRegex(ValueError, ".* contained an observation"):
            ax_client.update_trial_data(trial_index=idx, raw_data=1.0)
        ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)})
        metrics_in_data = ax_client.experiment.fetch_data(
        ).df["metric_name"].values
        self.assertIn("m1", metrics_in_data)
        self.assertIn("objective", metrics_in_data)
        self.assertEqual(ax_client.get_best_parameters()[0], params)
        params2, idy = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=idy, raw_data=(-1, 0.0))
        self.assertEqual(ax_client.get_best_parameters()[0], params2)
        params3, idx3 = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=idx3,
                                 raw_data=-2,
                                 metadata={"dummy": "test"})
        self.assertEqual(ax_client.get_best_parameters()[0], params3)
        self.assertEqual(
            ax_client.experiment.trials.get(2).run_metadata.get("dummy"),
            "test")
        best_trial_values = ax_client.get_best_parameters()[1]
        self.assertEqual(best_trial_values[0], {"objective": -2.0})
        self.assertTrue(
            math.isnan(best_trial_values[1]["objective"]["objective"]))

    def test_abandon_trial(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )

        # An abandoned trial adds no data.
        params, idx = ax_client.get_next_trial()
        ax_client.abandon_trial(trial_index=idx)
        data = ax_client.experiment.fetch_data()
        self.assertEqual(len(data.df.index), 0)

        # Can't update a completed trial.
        params2, idx2 = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=idx2,
                                 raw_data={"objective": (0, 0.0)})
        with self.assertRaisesRegex(ValueError, ".* in a terminal state."):
            ax_client.abandon_trial(trial_index=idx2)

    def test_ttl_trial(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )

        # A ttl trial that ends adds no data.
        params, idx = ax_client.get_next_trial(ttl_seconds=1)
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
        time.sleep(1)  # Wait for TTL to elapse.
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        # Also make sure we can no longer complete the trial as it is failed.
        with self.assertRaisesRegex(
                ValueError,
                ".* has been marked FAILED, so it no longer expects data."):
            ax_client.complete_trial(trial_index=idx,
                                     raw_data={"objective": (0, 0.0)})

        params2, idy = ax_client.get_next_trial(ttl_seconds=1)
        ax_client.complete_trial(trial_index=idy, raw_data=(-1, 0.0))
        self.assertEqual(ax_client.get_best_parameters()[0], params2)

    def test_start_and_end_time_in_trial_completion(self):
        start_time = current_timestamp_in_millis()
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        params, idx = ax_client.get_next_trial()
        ax_client.complete_trial(
            trial_index=idx,
            raw_data=1.0,
            metadata={
                "start_time": start_time,
                "end_time": current_timestamp_in_millis(),
            },
        )
        dat = ax_client.experiment.fetch_data().df
        self.assertGreater(dat["end_time"][0], dat["start_time"][0])

    def test_fail_on_batch(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        batch_trial = ax_client.experiment.new_batch_trial(
            generator_run=GeneratorRun(arms=[
                Arm(parameters={
                    "x": 0,
                    "y": 1
                }),
                Arm(parameters={
                    "x": 0,
                    "y": 1
                }),
            ]))
        with self.assertRaises(NotImplementedError):
            ax_client.complete_trial(batch_trial.index, 0)

    def test_log_failure(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        _, idx = ax_client.get_next_trial()
        ax_client.log_trial_failure(idx, metadata={"dummy": "test"})
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        self.assertEqual(
            ax_client.experiment.trials.get(idx).run_metadata.get("dummy"),
            "test")
        with self.assertRaisesRegex(ValueError, ".* no longer expects"):
            ax_client.complete_trial(idx, {})

    def test_attach_trial_and_get_trial_parameters(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={"x": 0.0, "y": 1.0})
        ax_client.complete_trial(trial_index=idx, raw_data=5)
        self.assertEqual(ax_client.get_best_parameters()[0], params)
        self.assertEqual(ax_client.get_trial_parameters(trial_index=idx), {
            "x": 0,
            "y": 1
        })
        with self.assertRaises(ValueError):
            ax_client.get_trial_parameters(
                trial_index=10)  # No trial #10 in experiment.
        with self.assertRaisesRegex(ValueError, ".* is of type"):
            ax_client.attach_trial({"x": 1, "y": 2})

    def test_attach_trial_ttl_seconds(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={
            "x": 0.0,
            "y": 1.0
        },
                                             ttl_seconds=1)
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
        time.sleep(1)  # Wait for TTL to elapse.
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        # Also make sure we can no longer complete the trial as it is failed.
        with self.assertRaisesRegex(
                ValueError,
                ".* has been marked FAILED, so it no longer expects data."):
            ax_client.complete_trial(trial_index=idx, raw_data=5)

        params2, idx2 = ax_client.attach_trial(parameters={
            "x": 0.0,
            "y": 1.0
        },
                                               ttl_seconds=1)
        ax_client.complete_trial(trial_index=idx2, raw_data=5)
        self.assertEqual(ax_client.get_best_parameters()[0], params2)
        self.assertEqual(ax_client.get_trial_parameters(trial_index=idx2), {
            "x": 0,
            "y": 1
        })

    def test_attach_trial_numpy(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={"x": 0.0, "y": 1.0})
        ax_client.complete_trial(trial_index=idx, raw_data=np.int32(5))
        self.assertEqual(ax_client.get_best_parameters()[0], params)

    def test_relative_oc_without_sq(self):
        """Must specify status quo to have relative outcome constraint."""
        ax_client = AxClient()
        with self.assertRaises(ValueError):
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {
                        "name": "x",
                        "type": "range",
                        "bounds": [-5.0, 10.0]
                    },
                    {
                        "name": "y",
                        "type": "range",
                        "bounds": [0.0, 15.0]
                    },
                ],
                objective_name="test_objective",
                minimize=True,
                outcome_constraints=["some_metric <= 4.0%"],
            )

    def test_recommended_parallelism(self):
        ax_client = AxClient()
        with self.assertRaisesRegex(ValueError, "No generation strategy"):
            ax_client.get_max_parallelism()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)])
        self.assertEqual(
            run_trials_using_recommended_parallelism(
                ax_client, ax_client.get_max_parallelism(), 20),
            0,
        )
        # With incorrect parallelism setting, the 'need more data' error should
        # still be raised.
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        with self.assertRaisesRegex(DataRequiredError,
                                    "All trials for current model "):
            run_trials_using_recommended_parallelism(ax_client, [(6, 6),
                                                                 (-1, 3)], 20)

    @patch.dict(sys.modules, {"ax.storage.sqa_store.structs": None})
    @patch.dict(sys.modules, {"sqalchemy": None})
    @patch("ax.service.ax_client.DBSettings", None)
    def test_no_sqa(self):
        # Make sure we couldn't import sqa_store.structs (this could happen when
        # SQLAlchemy is not installed).
        with self.assertRaises(ModuleNotFoundError):
            import ax.storage.sqa_store.structs  # noqa F401
        # Make sure we can still import the ax_client module.
        __import__("ax.service.ax_client")
        # Make sure we can still instantiate the client w/o db settings.
        AxClient()
        # DBSettings should be defined in `ax_client` now, but incorrectly typed
        # `db_settings` argument should still make instantiation fail.
        with self.assertRaisesRegex(ValueError,
                                    "`db_settings` argument should "):
            AxClient(db_settings="badly_typed_db_settings")

    def test_plotting_validation(self):
        ax_client = AxClient()
        ax_client.create_experiment(parameters=[{
            "name": "x3",
            "type": "fixed",
            "value": 2,
            "value_type": "int"
        }])
        with self.assertRaisesRegex(ValueError, ".* there are no trials"):
            ax_client.get_contour_plot()
        with self.assertRaisesRegex(ValueError, ".* there are no trials"):
            ax_client.get_feature_importances()
        ax_client.get_next_trial()
        with self.assertRaisesRegex(ValueError, ".* less than 2 parameters"):
            ax_client.get_contour_plot()
        ax_client = AxClient()
        ax_client.create_experiment(parameters=[
            {
                "name": "x",
                "type": "range",
                "bounds": [-5.0, 10.0]
            },
            {
                "name": "y",
                "type": "range",
                "bounds": [0.0, 15.0]
            },
        ])
        ax_client.get_next_trial()
        with self.assertRaisesRegex(ValueError, "If `param_x` is provided"):
            ax_client.get_contour_plot(param_x="y")
        with self.assertRaisesRegex(ValueError, "If `param_x` is provided"):
            ax_client.get_contour_plot(param_y="y")
        with self.assertRaisesRegex(ValueError, 'Parameter "x3"'):
            ax_client.get_contour_plot(param_x="x3", param_y="x3")
        with self.assertRaisesRegex(ValueError, 'Parameter "x4"'):
            ax_client.get_contour_plot(param_x="x", param_y="x4")
        with self.assertRaisesRegex(ValueError, 'Metric "nonexistent"'):
            ax_client.get_contour_plot(param_x="x",
                                       param_y="y",
                                       metric_name="nonexistent")
        with self.assertRaisesRegex(UnsupportedPlotError,
                                    "Could not obtain contour"):
            ax_client.get_contour_plot(param_x="x",
                                       param_y="y",
                                       metric_name="objective")
        with self.assertRaisesRegex(ValueError, "Could not obtain feature"):
            ax_client.get_feature_importances()

    def test_sqa_storage(self):
        init_test_engine_and_session_factory(force_init=True)
        config = SQAConfig()
        encoder = Encoder(config=config)
        decoder = Decoder(config=config)
        db_settings = DBSettings(encoder=encoder, decoder=decoder)
        ax_client = AxClient(db_settings=db_settings)
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        for _ in range(5):
            parameters, trial_index = ax_client.get_next_trial()
            ax_client.complete_trial(trial_index=trial_index,
                                     raw_data=branin(*parameters.values()))
        gs = ax_client.generation_strategy
        ax_client = AxClient(db_settings=db_settings)
        ax_client.load_experiment_from_database("test_experiment")
        # Trial #4 was completed after the last time the generation strategy
        # generated candidates, so the pre-save generation strategy was not
        # "aware" of the completion of trial #4. The post-restoration generation
        # strategy is aware of it, however, since it gets restored with the most
        # up-to-date experiment data. So add trial #4 to the seen completed
        # trials of the pre-storage GS to check that they are otherwise equal.
        gs._seen_trial_indices_by_status[TrialStatus.COMPLETED].add(4)
        self.assertEqual(gs, ax_client.generation_strategy)
        with self.assertRaises(ValueError):
            # Overwriting existing experiment.
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {
                        "name": "x",
                        "type": "range",
                        "bounds": [-5.0, 10.0]
                    },
                    {
                        "name": "y",
                        "type": "range",
                        "bounds": [0.0, 15.0]
                    },
                ],
                minimize=True,
            )
        with self.assertRaises(ValueError):
            # Overwriting existing experiment with overwrite flag with present
            # DB settings. This should fail as we no longer allow overwriting
            # experiments stored in the DB.
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[{
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                }],
                overwrite_existing_experiment=True,
            )
        # Original experiment should still be in DB and not have been overwritten.
        self.assertEqual(len(ax_client.experiment.trials), 5)

    def test_overwrite(self):
        init_test_engine_and_session_factory(force_init=True)
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )

        # Log a trial
        parameters, trial_index = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=trial_index,
                                 raw_data=branin(*parameters.values()))

        with self.assertRaises(ValueError):
            # Overwriting existing experiment.
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {
                        "name": "x",
                        "type": "range",
                        "bounds": [-5.0, 10.0]
                    },
                    {
                        "name": "y",
                        "type": "range",
                        "bounds": [0.0, 15.0]
                    },
                ],
                minimize=True,
            )
        # Overwriting existing experiment with overwrite flag.
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x1",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "x2",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            overwrite_existing_experiment=True,
        )
        # There should be no trials, as we just put in a fresh experiment.
        self.assertEqual(len(ax_client.experiment.trials), 0)

        # Log a trial
        parameters, trial_index = ax_client.get_next_trial()
        self.assertIn("x1", parameters.keys())
        self.assertIn("x2", parameters.keys())
        ax_client.complete_trial(trial_index=trial_index,
                                 raw_data=branin(*parameters.values()))

    def test_fixed_random_seed_reproducibility(self):
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(parameters=[
            {
                "name": "x",
                "type": "range",
                "bounds": [-5.0, 10.0]
            },
            {
                "name": "y",
                "type": "range",
                "bounds": [0.0, 15.0]
            },
        ])
        for _ in range(5):
            params, idx = ax_client.get_next_trial()
            ax_client.complete_trial(idx,
                                     branin(params.get("x"), params.get("y")))
        trial_parameters_1 = [
            t.arm.parameters for t in ax_client.experiment.trials.values()
        ]
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(parameters=[
            {
                "name": "x",
                "type": "range",
                "bounds": [-5.0, 10.0]
            },
            {
                "name": "y",
                "type": "range",
                "bounds": [0.0, 15.0]
            },
        ])
        for _ in range(5):
            params, idx = ax_client.get_next_trial()
            ax_client.complete_trial(idx,
                                     branin(params.get("x"), params.get("y")))
        trial_parameters_2 = [
            t.arm.parameters for t in ax_client.experiment.trials.values()
        ]
        self.assertEqual(trial_parameters_1, trial_parameters_2)

    def test_init_position_saved(self):
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            name="sobol_init_position_test",
        )
        for _ in range(4):
            # For each generated trial, snapshot the client before generating it,
            # then recreate client, regenerate the trial and compare the trial
            # generated before and after snapshotting. If the state of Sobol is
            # recorded correctly, the newly generated trial will be the same as
            # the one generated before the snapshotting.
            serialized = ax_client.to_json_snapshot()
            params, idx = ax_client.get_next_trial()
            ax_client = AxClient.from_json_snapshot(serialized)
            with self.subTest(ax=ax_client, params=params, idx=idx):
                new_params, new_idx = ax_client.get_next_trial()
                self.assertEqual(params, new_params)
                self.assertEqual(idx, new_idx)
                self.assertEqual(
                    ax_client.experiment.trials[idx]._generator_run.
                    _model_state_after_gen["init_position"],
                    idx + 1,
                )
            ax_client.complete_trial(idx,
                                     branin(params.get("x"), params.get("y")))

    def test_unnamed_experiment_snapshot(self):
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(parameters=[
            {
                "name": "x",
                "type": "range",
                "bounds": [-5.0, 10.0]
            },
            {
                "name": "y",
                "type": "range",
                "bounds": [0.0, 15.0]
            },
        ])
        serialized = ax_client.to_json_snapshot()
        ax_client = AxClient.from_json_snapshot(serialized)
        self.assertIsNone(ax_client.experiment._name)

    @patch(
        "ax.modelbridge.base.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        "ax.modelbridge.random.RandomModelBridge.get_training_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        "ax.modelbridge.random.RandomModelBridge._predict",
        autospec=True,
        return_value=[get_observation1trans().data],
    )
    def test_get_model_predictions(self, _predict, _tr_data, _obs_from_data):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
            objective_name="a",
        )
        ax_client.get_next_trial()
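        # The generated arm is renamed to "1_1" so it lines up with the arm in the
        # mocked observation data; predictions are then returned per trial as
        # {trial_index: {metric_name: (mean, SEM)}}.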
        ax_client.experiment.trials[0].arm._name = "1_1"
        self.assertEqual(ax_client.get_model_predictions(),
                         {0: {
                             "a": (9.0, 1.0)
                         }})

    def test_deprecated_save_load_method_errors(self):
        ax_client = AxClient()
        with self.assertRaises(NotImplementedError):
            ax_client.save()
        with self.assertRaises(NotImplementedError):
            ax_client.load()
        with self.assertRaises(NotImplementedError):
            ax_client.load_experiment("test_experiment")
        with self.assertRaises(NotImplementedError):
            ax_client.get_recommended_max_parallelism()

    def test_find_last_trial_with_parameterization(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
            objective_name="a",
        )
        params, trial_idx = ax_client.get_next_trial()
        found_trial_idx = ax_client._find_last_trial_with_parameterization(
            parameterization=params)
        self.assertEqual(found_trial_idx, trial_idx)
        # Check that it's indeed the _last_ trial with params that is found.
        _, new_trial_idx = ax_client.attach_trial(parameters=params)
        found_trial_idx = ax_client._find_last_trial_with_parameterization(
            parameterization=params)
        self.assertEqual(found_trial_idx, new_trial_idx)
        with self.assertRaisesRegex(ValueError, "No .* matches"):
            found_trial_idx = ax_client._find_last_trial_with_parameterization(
                parameterization={k: v + 1.0
                                  for k, v in params.items()})

    def test_verify_parameterization(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
            objective_name="a",
        )
        params, trial_idx = ax_client.get_next_trial()
        self.assertTrue(
            ax_client.verify_trial_parameterization(trial_index=trial_idx,
                                                    parameterization=params))
        # Make sure it still works if the ordering in the parameterization differs.
        self.assertTrue(
            ax_client.verify_trial_parameterization(
                trial_index=trial_idx,
                parameterization={
                    k: params[k]
                    for k in reversed(list(params.keys()))
                },
            ))
        self.assertFalse(
            ax_client.verify_trial_parameterization(
                trial_index=trial_idx,
                parameterization={k: v + 1.0
                                  for k, v in params.items()},
            ))

    def test_tracking_metric_addition(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
            objective_name="a",
        )
        params, trial_idx = ax_client.get_next_trial()
        self.assertEqual(list(ax_client.experiment.metrics.keys()), ["a"])
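        # Reporting data for "b", which was not declared at experiment creation,
        # should add it to the experiment as a tracking metric alongside "a".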
        ax_client.complete_trial(trial_index=trial_idx,
                                 raw_data={
                                     "a": 1.0,
                                     "b": 2.0
                                 })
        self.assertEqual(list(ax_client.experiment.metrics.keys()), ["b", "a"])

    @patch(
        "ax.core.experiment.Experiment.new_trial",
        side_effect=RuntimeError("cholesky_cpu error - bad matrix"),
    )
    def test_annotate_exception(self, _):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
            objective_name="a",
        )
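        # The low-level "cholesky_cpu" RuntimeError should be re-raised with an
        # added hint about what Cholesky errors typically indicate.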
        with self.assertRaisesRegex(
                expected_exception=RuntimeError,
                expected_regex="Cholesky errors typically occur",
        ):
            ax_client.get_next_trial()
Example #6
0
class ArrayModelBridgeTest(TestCase):
    @patch(
        f"{ModelBridge.__module__}.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        f"{ModelBridge.__module__}.unwrap_observation_data",
        autospec=True,
        return_value=(2, 2),
    )
    @patch(
        f"{ModelBridge.__module__}.gen_arms",
        autospec=True,
        return_value=([Arm(parameters={})], {}),
    )
    @patch(
        f"{ModelBridge.__module__}.ModelBridge.predict",
        autospec=True,
        return_value=({
            "m": [1.0]
        }, {
            "m": {
                "m": [2.0]
            }
        }),
    )
    @patch(f"{ModelBridge.__module__}.ModelBridge._fit", autospec=True)
    @patch(
        f"{NumpyModel.__module__}.NumpyModel.best_point",
        return_value=(np.array([1, 2])),
        autospec=True,
    )
    @patch(
        f"{NumpyModel.__module__}.NumpyModel.gen",
        return_value=(np.array([[1, 2]]), np.array([1]), {}, []),
        autospec=True,
    )
    def test_best_point(
        self,
        _mock_gen,
        _mock_best_point,
        _mock_fit,
        _mock_predict,
        _mock_gen_arms,
        _mock_unwrap,
        _mock_obs_from_data,
    ):
        exp = Experiment(search_space=get_search_space_for_range_value(),
                         name="test")
        modelbridge = ArrayModelBridge(
            search_space=get_search_space_for_range_value(),
            model=NumpyModel(),
            transforms=[t1, t2],
            experiment=exp,
            data=Data(),
        )
        self.assertEqual(list(modelbridge.transforms.keys()),
                         ["Cast", "t1", "t2"])
        # _fit is mocked, which typically sets this.
        modelbridge.outcomes = ["a"]
        run = modelbridge.gen(
            n=1,
            optimization_config=OptimizationConfig(
                objective=Objective(metric=Metric("a"), minimize=False),
                outcome_constraints=[],
            ),
        )
        arm, predictions = run.best_arm_predictions
        self.assertEqual(arm.parameters, {})
        self.assertEqual(predictions[0], {"m": 1.0})
        self.assertEqual(predictions[1], {"m": {"m": 2.0}})
        # Check that an optimization config is required.
        with self.assertRaises(ValueError):
            run = modelbridge.gen(n=1, optimization_config=None)

    @patch(
        f"{ModelBridge.__module__}.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        f"{ModelBridge.__module__}.unwrap_observation_data",
        autospec=True,
        return_value=(2, 2),
    )
    @patch(
        f"{ModelBridge.__module__}.gen_arms",
        autospec=True,
        return_value=[Arm(parameters={})],
    )
    @patch(
        f"{ModelBridge.__module__}.ModelBridge.predict",
        autospec=True,
        return_value=({
            "m": [1.0]
        }, {
            "m": {
                "m": [2.0]
            }
        }),
    )
    @patch(f"{ModelBridge.__module__}.ModelBridge._fit", autospec=True)
    @patch(
        f"{NumpyModel.__module__}.NumpyModel.feature_importances",
        return_value=np.array([[[1.0]], [[2.0]]]),
        autospec=True,
    )
    def test_importances(
        self,
        _mock_feature_importances,
        _mock_fit,
        _mock_predict,
        _mock_gen_arms,
        _mock_unwrap,
        _mock_obs_from_data,
    ):
        exp = Experiment(search_space=get_search_space_for_range_value(),
                         name="test")
        modelbridge = ArrayModelBridge(
            search_space=get_search_space_for_range_value(),
            model=NumpyModel(),
            transforms=[t1, t2],
            experiment=exp,
            data=Data(),
        )
        modelbridge.outcomes = ["a", "b"]
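        # The mocked importances have shape (num_outcomes, 1, num_features); the
        # bridge is expected to expose them per metric as {parameter_name: values}.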
        self.assertEqual(modelbridge.feature_importances("a"), {"x": [1.0]})
        self.assertEqual(modelbridge.feature_importances("b"), {"x": [2.0]})

    @patch(
        f"{NumpyModel.__module__}.NumpyModel.gen",
        return_value=(
            np.array([[1, 2], [2, 3]]),
            np.array([1, 2]),
            {},
            [{
                "some_key": "some_value_0"
            }, {
                "some_key": "some_value_1"
            }],
        ),
        autospec=True,
    )
    @patch(f"{NumpyModel.__module__}.NumpyModel.update", autospec=True)
    @patch(f"{NumpyModel.__module__}.NumpyModel.fit", autospec=True)
    def test_candidate_metadata_propagation(self, mock_model_fit,
                                            mock_model_update, mock_model_gen):
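        # `NumpyModel.gen` is mocked to return (X, weights, gen_metadata,
        # candidate_metadata): two candidate points, each with its own per-arm
        # metadata dict that should end up on the resulting generator run.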
        exp = get_branin_experiment(with_status_quo=True, with_batch=True)
        # Check that the metadata is correctly re-added to observation
        # features during `fit`.
        preexisting_batch_gr = exp.trials[0]._generator_run_structs[
            0].generator_run
        preexisting_batch_gr._candidate_metadata_by_arm_signature = {
            preexisting_batch_gr.arms[0].signature: {
                "preexisting_batch_cand_metadata": "some_value"
            }
        }
        modelbridge = ArrayModelBridge(
            search_space=exp.search_space,
            experiment=exp,
            model=NumpyModel(),
            data=get_branin_data(),
        )
        self.assertTrue(
            np.array_equal(
                mock_model_fit.call_args[1].get("Xs"),
                np.array([[list(exp.trials[0].arms[0].parameters.values())]]),
            ))
        self.assertEqual(
            mock_model_fit.call_args[1].get("candidate_metadata"),
            [[{
                "preexisting_batch_cand_metadata": "some_value"
            }]],
        )

        # Check that `gen` correctly propagates the metadata to the GR.
        gr = modelbridge.gen(n=1)
        self.assertEqual(
            gr.candidate_metadata_by_arm_signature,
            {
                gr.arms[0].signature: {
                    "some_key": "some_value_0"
                },
                gr.arms[1].signature: {
                    "some_key": "some_value_1"
                },
            },
        )
        # Check that the metadata is correctly re-added to observation
        # features during `update`.
        batch = exp.new_batch_trial(generator_run=gr)
        modelbridge.update(
            experiment=exp,
            new_data=get_branin_data(trial_indices=[batch.index]))
        self.assertTrue(
            np.array_equal(
                mock_model_update.call_args[1].get("Xs"),
                np.array(
                    [[list(exp.trials[0].arms[0].parameters.values()), [1, 2]]]),
            ))
        self.assertEqual(
            mock_model_update.call_args[1].get("candidate_metadata"),
            [[
                {
                    "preexisting_batch_cand_metadata": "some_value"
                },
                # The new data contains observations only for arm '1_0', not for
                # '1_1', so we don't expect to see '{"some_key": "some_value_1"}'
                # in the candidate metadata.
                {
                    "some_key": "some_value_0"
                },
            ]],
        )

        # Check that `None` candidate metadata is handled correctly.
        mock_model_gen.return_value = (
            np.array([[2, 4], [3, 5]]),
            np.array([1, 2]),
            None,
            {},
        )
        gr = modelbridge.gen(n=1)
        self.assertIsNone(gr.candidate_metadata_by_arm_signature)
        # Check that the metadata is correctly re-added to observation
        # features during `update`.
        batch = exp.new_batch_trial(generator_run=gr)
        modelbridge.update(
            experiment=exp,
            new_data=get_branin_data(trial_indices=[batch.index]))
        self.assertTrue(
            np.array_equal(
                mock_model_update.call_args[1].get("Xs"),
                np.array([[
                    list(exp.trials[0].arms[0].parameters.values()), [1, 2],
                    [2, 4]
                ]]),
            ))
        self.assertEqual(
            mock_model_update.call_args[1].get("candidate_metadata"),
            [[
                {
                    "preexisting_batch_cand_metadata": "some_value"
                },
                {
                    "some_key": "some_value_0"
                },
                {},
            ]],
        )

        # Check that the absence of candidate metadata is handled correctly.
        exp = get_branin_experiment(with_status_quo=True)
        modelbridge = ArrayModelBridge(search_space=exp.search_space,
                                       experiment=exp,
                                       model=NumpyModel())
        # Hack in outcome names to bypass validation (since we instantiated model
        # without data).
        modelbridge.outcomes = modelbridge._metric_names = next(
            iter(exp.metrics))
        gr = modelbridge.gen(n=1)
        self.assertIsNone(
            mock_model_fit.call_args[1].get("candidate_metadata"))
        self.assertIsNone(gr.candidate_metadata_by_arm_signature)
        batch = exp.new_batch_trial(generator_run=gr)
        modelbridge.update(
            experiment=exp,
            new_data=get_branin_data(trial_indices=[batch.index]))
Example #7
0
class ArrayModelBridgeTest(TestCase):
    @patch(
        f"{ModelBridge.__module__}.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        f"{ModelBridge.__module__}.unwrap_observation_data",
        autospec=True,
        return_value=(2, 2),
    )
    @patch(
        f"{ModelBridge.__module__}.gen_arms",
        autospec=True,
        return_value=[Arm(parameters={})],
    )
    @patch(
        f"{ModelBridge.__module__}.ModelBridge.predict",
        autospec=True,
        return_value=({
            "m": [1.0]
        }, {
            "m": {
                "m": [2.0]
            }
        }),
    )
    @patch(f"{ModelBridge.__module__}.ModelBridge._fit", autospec=True)
    @patch(
        f"{NumpyModel.__module__}.NumpyModel.best_point",
        return_value=(np.array([1, 2])),
        autospec=True,
    )
    @patch(
        f"{NumpyModel.__module__}.NumpyModel.gen",
        return_value=(np.array([[1, 2]]), np.array([1]), {}),
        autospec=True,
    )
    def test_best_point(
        self,
        _mock_gen,
        _mock_best_point,
        _mock_fit,
        _mock_predict,
        _mock_gen_arms,
        _mock_unwrap,
        _mock_obs_from_data,
    ):
        exp = Experiment(get_search_space_for_range_value(), "test")
        modelbridge = ArrayModelBridge(get_search_space_for_range_value(),
                                       NumpyModel(), [t1, t2], exp, 0)
        self.assertEqual(list(modelbridge.transforms.keys()),
                         ["Cast", "t1", "t2"])
        # _fit is mocked, which typically sets this.
        modelbridge.outcomes = ["a"]
        run = modelbridge.gen(
            n=1,
            optimization_config=OptimizationConfig(
                objective=Objective(metric=Metric("a"), minimize=False),
                outcome_constraints=[],
            ),
        )
        arm, predictions = run.best_arm_predictions
        self.assertEqual(arm.parameters, {})
        self.assertEqual(predictions[0], {"m": 1.0})
        self.assertEqual(predictions[1], {"m": {"m": 2.0}})
        # Check that an optimization config is required.
        with self.assertRaises(ValueError):
            run = modelbridge.gen(n=1, optimization_config=None)

    @patch(
        f"{ModelBridge.__module__}.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        f"{ModelBridge.__module__}.unwrap_observation_data",
        autospec=True,
        return_value=(2, 2),
    )
    @patch(
        f"{ModelBridge.__module__}.gen_arms",
        autospec=True,
        return_value=[Arm(parameters={})],
    )
    @patch(
        f"{ModelBridge.__module__}.ModelBridge.predict",
        autospec=True,
        return_value=({
            "m": [1.0]
        }, {
            "m": {
                "m": [2.0]
            }
        }),
    )
    @patch(f"{ModelBridge.__module__}.ModelBridge._fit", autospec=True)
    @patch(
        f"{NumpyModel.__module__}.NumpyModel.feature_importances",
        return_value=np.array([[[1.0]], [[2.0]]]),
        autospec=True,
    )
    def test_importances(
        self,
        _mock_feature_importances,
        _mock_fit,
        _mock_predict,
        _mock_gen_arms,
        _mock_unwrap,
        _mock_obs_from_data,
    ):
        exp = Experiment(get_search_space_for_range_value(), "test")
        modelbridge = ArrayModelBridge(get_search_space_for_range_value(),
                                       NumpyModel(), [t1, t2], exp, 0)
        modelbridge.outcomes = ["a", "b"]
        self.assertEqual(modelbridge.feature_importances("a"), {"x": [1.0]})
        self.assertEqual(modelbridge.feature_importances("b"), {"x": [2.0]})
Example #8
0
class TestAxClient(TestCase):
    """Tests service-like API functionality."""

    def test_interruption(self) -> None:
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test",
            parameters=[  # pyre-fixme[6]: expected union that should include
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            objective_name="branin",
            minimize=True,
        )
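        # Round-trip the client through a JSON snapshot after every completed trial
        # to make sure the optimization can be interrupted and resumed.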
        for i in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            self.assertFalse(  # There should be non-complete trials.
                all(t.status.is_terminal for t in ax_client.experiment.trials.values())
            )
            x1, x2 = parameterization.get("x1"), parameterization.get("x2")
            ax_client.complete_trial(
                trial_index,
                raw_data=checked_cast(
                    float, branin(checked_cast(float, x1), checked_cast(float, x2))
                ),
            )
            old_client = ax_client
            serialized = ax_client.to_json_snapshot()
            ax_client = AxClient.from_json_snapshot(serialized)
            self.assertEqual(len(ax_client.experiment.trials.keys()), i + 1)
            self.assertIsNot(ax_client, old_client)
            self.assertTrue(  # There should be no non-complete trials.
                all(t.status.is_terminal for t in ax_client.experiment.trials.values())
            )

    def test_default_generation_strategy(self) -> None:
        """Test that Sobol+GPEI is used if no GenerationStrategy is provided."""
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[  # pyre-fixme[6]: expected union that should include
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            objective_name="branin",
            minimize=True,
        )
        self.assertEqual(
            [s.model for s in not_none(ax_client.generation_strategy)._steps],
            [Models.SOBOL, Models.GPEI],
        )
        with self.assertRaisesRegex(ValueError, ".* no trials."):
            ax_client.get_optimization_trace(objective_optimum=branin.fmin)
        for i in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            x1, x2 = parameterization.get("x1"), parameterization.get("x2")
            ax_client.complete_trial(
                trial_index,
                raw_data={
                    "branin": (
                        checked_cast(
                            float,
                            branin(checked_cast(float, x1), checked_cast(float, x2)),
                        ),
                        0.0,
                    )
                },
                sample_size=i,
            )
            if i < 5:
                with self.assertRaisesRegex(ValueError, "Could not obtain contour"):
                    ax_client.get_contour_plot(param_x="x1", param_y="x2")
        ax_client.get_optimization_trace(objective_optimum=branin.fmin)
        ax_client.get_contour_plot()
        self.assertIn("x1", ax_client.get_trials_data_frame())
        self.assertIn("x2", ax_client.get_trials_data_frame())
        self.assertIn("branin", ax_client.get_trials_data_frame())
        self.assertEqual(len(ax_client.get_trials_data_frame()), 6)
        # Test that Sobol is chosen when all parameters are choice.
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[  # pyre-fixme[6]: expected union that should include
                {"name": "x1", "type": "choice", "values": [1, 2, 3]},
                {"name": "x2", "type": "choice", "values": [1, 2, 3]},
            ]
        )
        self.assertEqual(
            [s.model for s in not_none(ax_client.generation_strategy)._steps],
            [Models.SOBOL],
        )
        self.assertEqual(ax_client.get_recommended_max_parallelism(), [(-1, -1)])
        self.assertTrue(ax_client.get_trials_data_frame().empty)

    def test_create_experiment(self) -> None:
        """Test basic experiment creation."""
        ax_client = AxClient(
            GenerationStrategy(steps=[GenerationStep(model=Models.SOBOL, num_arms=30)])
        )
        with self.assertRaisesRegex(ValueError, "Experiment not set on Ax client"):
            ax_client.experiment
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x1",
                    "type": "range",
                    "bounds": [0.001, 0.1],
                    "value_type": "float",
                    "log_scale": True,
                },
                {
                    "name": "x2",
                    "type": "choice",
                    "values": [1, 2, 3],
                    "value_type": "int",
                    "is_ordered": True,
                },
                {"name": "x3", "type": "fixed", "value": 2, "value_type": "int"},
                {
                    "name": "x4",
                    "type": "range",
                    "bounds": [1.0, 3.0],
                    "value_type": "int",
                },
                {
                    "name": "x5",
                    "type": "choice",
                    "values": ["one", "two", "three"],
                    "value_type": "str",
                },
                {
                    "name": "x6",
                    "type": "range",
                    "bounds": [1.0, 3.0],
                    "value_type": "int",
                },
            ],
            objective_name="test_objective",
            minimize=True,
            outcome_constraints=["some_metric >= 3", "some_metric <= 4.0"],
            parameter_constraints=["x4 <= x6"],
        )
        assert ax_client._experiment is not None
        self.assertEqual(ax_client._experiment, ax_client.experiment)
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x1"],
            RangeParameter(
                name="x1",
                parameter_type=ParameterType.FLOAT,
                lower=0.001,
                upper=0.1,
                log_scale=True,
            ),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x2"],
            ChoiceParameter(
                name="x2",
                parameter_type=ParameterType.INT,
                values=[1, 2, 3],
                is_ordered=True,
            ),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x3"],
            FixedParameter(name="x3", parameter_type=ParameterType.INT, value=2),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x4"],
            RangeParameter(
                name="x4", parameter_type=ParameterType.INT, lower=1.0, upper=3.0
            ),
        )
        self.assertEqual(
            ax_client._experiment.search_space.parameters["x5"],
            ChoiceParameter(
                name="x5",
                parameter_type=ParameterType.STRING,
                values=["one", "two", "three"],
            ),
        )
        self.assertEqual(
            ax_client._experiment.optimization_config.outcome_constraints[0],
            OutcomeConstraint(
                metric=Metric(name="some_metric"),
                op=ComparisonOp.GEQ,
                bound=3.0,
                relative=False,
            ),
        )
        self.assertEqual(
            ax_client._experiment.optimization_config.outcome_constraints[1],
            OutcomeConstraint(
                metric=Metric(name="some_metric"),
                op=ComparisonOp.LEQ,
                bound=4.0,
                relative=False,
            ),
        )
        self.assertTrue(ax_client._experiment.optimization_config.objective.minimize)

    def test_constraint_same_as_objective(self):
        """Check that we do not allow constraints on the objective metric."""
        ax_client = AxClient(
            GenerationStrategy(steps=[GenerationStep(model=Models.SOBOL, num_arms=30)])
        )
        with self.assertRaises(ValueError):
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {"name": "x3", "type": "fixed", "value": 2, "value_type": "int"}
                ],
                objective_name="test_objective",
                outcome_constraints=["test_objective >= 3"],
            )

    def test_raw_data_format(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        for _ in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            x1, x2 = parameterization.get("x1"), parameterization.get("x2")
            ax_client.complete_trial(trial_index, raw_data=(branin(x1, x2), 0.0))
        with self.assertRaisesRegex(ValueError, "Raw data has an invalid type"):
            ax_client.complete_trial(trial_index, raw_data="invalid_data")

    def test_raw_data_format_with_fidelities(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
            ],
            minimize=True,
        )
        for _ in range(6):
            parameterization, trial_index = ax_client.get_next_trial()
            x1, x2 = parameterization.get("x1"), parameterization.get("x2")
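            # Multi-fidelity raw data: a list of (fidelity parameterization,
            # {metric_name: (mean, SEM)}) tuples, here evaluated at two fidelities.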
            ax_client.complete_trial(
                trial_index,
                raw_data=[
                    ({"x2": x2 / 2.0}, {"objective": (branin(x1, x2 / 2.0), 0.0)}),
                    ({"x2": x2}, {"objective": (branin(x1, x2), 0.0)}),
                ],
            )

    def test_keep_generating_without_data(self):
        # Check that normally the number of arms to generate is enforced.
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        for _ in range(5):
            parameterization, trial_index = ax_client.get_next_trial()
        with self.assertRaisesRegex(ValueError, "All trials for current model"):
            ax_client.get_next_trial()
        # Check that with enforce_sequential_optimization off, we can keep
        # generating.
        ax_client = AxClient(enforce_sequential_optimization=False)
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        for _ in range(10):
            parameterization, trial_index = ax_client.get_next_trial()

    def test_trial_completion(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
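        # `complete_trial` accepts raw data as {metric_name: (mean, SEM)}, as a
        # bare (mean, SEM) tuple, or as a single mean value, as exercised below.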
        params, idx = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
        self.assertEqual(ax_client.get_best_parameters()[0], params)
        params2, idx2 = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=idx2, raw_data=(-1, 0.0))
        self.assertEqual(ax_client.get_best_parameters()[0], params2)
        params3, idx3 = ax_client.get_next_trial()
        ax_client.complete_trial(
            trial_index=idx3, raw_data=-2, metadata={"dummy": "test"}
        )
        self.assertEqual(ax_client.get_best_parameters()[0], params3)
        self.assertEqual(
            ax_client.experiment.trials.get(2).run_metadata.get("dummy"), "test"
        )
        best_trial_values = ax_client.get_best_parameters()[1]
        self.assertEqual(best_trial_values[0], {"objective": -2.0})
        self.assertTrue(math.isnan(best_trial_values[1]["objective"]["objective"]))

    def test_start_and_end_time_in_trial_completion(self):
        start_time = current_timestamp_in_millis()
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        params, idx = ax_client.get_next_trial()
        ax_client.complete_trial(
            trial_index=idx,
            raw_data=1.0,
            metadata={
                "start_time": start_time,
                "end_time": current_timestamp_in_millis(),
            },
        )
        dat = ax_client.experiment.fetch_data().df
        self.assertGreater(dat["end_time"][0], dat["start_time"][0])

    def test_fail_on_batch(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        batch_trial = ax_client.experiment.new_batch_trial(
            generator_run=GeneratorRun(
                arms=[
                    Arm(parameters={"x1": 0, "x2": 1}),
                    Arm(parameters={"x1": 0, "x2": 1}),
                ]
            )
        )
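        # AxClient only manages single-arm `Trial`s, so completing a `BatchTrial`
        # through it is expected to raise.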
        with self.assertRaises(NotImplementedError):
            ax_client.complete_trial(batch_trial.index, 0)

    def test_log_failure(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        _, idx = ax_client.get_next_trial()
        ax_client.log_trial_failure(idx, metadata={"dummy": "test"})
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        self.assertEqual(
            ax_client.experiment.trials.get(idx).run_metadata.get("dummy"), "test"
        )

    def test_attach_trial_and_get_trial_parameters(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={"x1": 0, "x2": 1})
        ax_client.complete_trial(trial_index=idx, raw_data=5)
        self.assertEqual(ax_client.get_best_parameters()[0], params)
        self.assertEqual(
            ax_client.get_trial_parameters(trial_index=idx), {"x1": 0, "x2": 1}
        )
        with self.assertRaises(ValueError):
            ax_client.get_trial_parameters(
                trial_index=10
            )  # No trial #10 in experiment.

    def test_attach_trial_numpy(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={"x1": 0, "x2": 1})
        ax_client.complete_trial(trial_index=idx, raw_data=np.int32(5))
        self.assertEqual(ax_client.get_best_parameters()[0], params)

    def test_relative_oc_without_sq(self):
        """Must specify status quo to have relative outcome constraint."""
        ax_client = AxClient()
        with self.assertRaises(ValueError):
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                    {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
                ],
                objective_name="test_objective",
                minimize=True,
                outcome_constraints=["some_metric <= 4.0%"],
            )

    @patch("ax.service.utils.dispatch.Models", FakeModels)
    def test_recommended_parallelism(self):
        ax_client = AxClient()
        with self.assertRaisesRegex(ValueError, "No generation strategy"):
            ax_client.get_recommended_max_parallelism()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        self.assertEqual(ax_client.get_recommended_max_parallelism(), [(5, 5), (-1, 3)])
        self.assertEqual(
            run_trials_using_recommended_parallelism(
                ax_client, ax_client.get_recommended_max_parallelism(), 20
            ),
            0,
        )
        # With incorrect parallelism setting, the 'need more data' error should
        # still be raised.
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        with self.assertRaisesRegex(ValueError, "All trials for current model "):
            run_trials_using_recommended_parallelism(ax_client, [(6, 6), (-1, 3)], 20)

    @patch.dict(sys.modules, {"ax.storage.sqa_store.structs": None})
    def test_no_sqa(self):
        # Pretend we couldn't import sqa_store.structs (this could happen when
        # SQLAlchemy is not installed).
        patcher = patch("ax.service.ax_client.DBSettings", None)
        patcher.start()
        with self.assertRaises(ModuleNotFoundError):
            import ax.storage.sqa_store.structs  # noqa F401
        AxClient()  # Make sure we still can instantiate client w/o db settings.
        # Even with correctly typed DBSettings, `AxClient` instantiation should
        # fail here, because `DBSettings` is mocked to None in `ax_client`.
        db_settings = DBSettings()
        self.assertIsInstance(db_settings, DBSettings)
        with self.assertRaisesRegex(ValueError, "`db_settings` argument should "):
            AxClient(db_settings=db_settings)
        patcher.stop()
        # DBSettings should be defined in `ax_client` now, but incorrectly typed
        # `db_settings` argument should still make instantiation fail.
        with self.assertRaisesRegex(ValueError, "`db_settings` argument should "):
            AxClient(db_settings="badly_typed_db_settings")

    def test_plotting_validation(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x3", "type": "fixed", "value": 2, "value_type": "int"}
            ]
        )
        with self.assertRaisesRegex(ValueError, ".* there are no trials"):
            ax_client.get_contour_plot()
        ax_client.get_next_trial()
        with self.assertRaisesRegex(ValueError, ".* less than 2 parameters"):
            ax_client.get_contour_plot()
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ]
        )
        ax_client.get_next_trial()
        with self.assertRaisesRegex(ValueError, "If `param_x` is provided"):
            ax_client.get_contour_plot(param_x="x2")
        with self.assertRaisesRegex(ValueError, "If `param_x` is provided"):
            ax_client.get_contour_plot(param_y="x2")
        with self.assertRaisesRegex(ValueError, 'Parameter "x3"'):
            ax_client.get_contour_plot(param_x="x3", param_y="x3")
        with self.assertRaisesRegex(ValueError, 'Parameter "x4"'):
            ax_client.get_contour_plot(param_x="x1", param_y="x4")
        with self.assertRaisesRegex(ValueError, 'Metric "nonexistent"'):
            ax_client.get_contour_plot(
                param_x="x1", param_y="x2", metric_name="nonexistent"
            )
        with self.assertRaisesRegex(ValueError, "Could not obtain contour"):
            ax_client.get_contour_plot(
                param_x="x1", param_y="x2", metric_name="objective"
            )

    def test_sqa_storage(self):
        init_test_engine_and_session_factory(force_init=True)
        config = SQAConfig()
        encoder = Encoder(config=config)
        decoder = Decoder(config=config)
        db_settings = DBSettings(encoder=encoder, decoder=decoder)
        ax_client = AxClient(db_settings=db_settings)
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        for _ in range(5):
            parameters, trial_index = ax_client.get_next_trial()
            ax_client.complete_trial(
                trial_index=trial_index, raw_data=branin(*parameters.values())
            )
        gs = ax_client.generation_strategy
        ax_client = AxClient(db_settings=db_settings)
        ax_client.load_experiment_from_database("test_experiment")
        self.assertEqual(gs, ax_client.generation_strategy)
        with self.assertRaises(ValueError):
            # Overwriting existing experiment.
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                    {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
                ],
                minimize=True,
            )
        # Overwriting existing experiment with overwrite flag.
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]}],
            overwrite_existing_experiment=True,
        )
        # There should be no trials, as we just put in a fresh experiment.
        self.assertEqual(len(ax_client.experiment.trials), 0)

    def test_fixed_random_seed_reproducibility(self):
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ]
        )
        for _ in range(5):
            params, idx = ax_client.get_next_trial()
            ax_client.complete_trial(idx, branin(params.get("x1"), params.get("x2")))
        trial_parameters_1 = [
            t.arm.parameters for t in ax_client.experiment.trials.values()
        ]
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ]
        )
        for _ in range(5):
            params, idx = ax_client.get_next_trial()
            ax_client.complete_trial(idx, branin(params.get("x1"), params.get("x2")))
        trial_parameters_2 = [
            t.arm.parameters for t in ax_client.experiment.trials.values()
        ]
        self.assertEqual(trial_parameters_1, trial_parameters_2)

    def test_init_position_saved(self):
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            name="sobol_init_position_test",
        )
        for _ in range(4):
            # For each generated trial, snapshot the client before generating it,
            # then recreate client, regenerate the trial and compare the trial
            # generated before and after snapshotting. If the state of Sobol is
            # recorded correctly, the newly generated trial will be the same as
            # the one generated before the snapshotting.
            serialized = ax_client.to_json_snapshot()
            params, idx = ax_client.get_next_trial()
            ax_client = AxClient.from_json_snapshot(serialized)
            with self.subTest(ax=ax_client, params=params, idx=idx):
                new_params, new_idx = ax_client.get_next_trial()
                self.assertEqual(params, new_params)
                self.assertEqual(idx, new_idx)
                self.assertEqual(
                    ax_client.experiment.trials[idx]._generator_run._model_kwargs[
                        "init_position"
                    ],
                    idx + 1,
                )
            ax_client.complete_trial(idx, branin(params.get("x1"), params.get("x2")))

    @patch(
        "ax.modelbridge.base.observations_from_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        "ax.modelbridge.random.RandomModelBridge.get_training_data",
        autospec=True,
        return_value=([get_observation1()]),
    )
    @patch(
        "ax.modelbridge.random.RandomModelBridge._predict",
        autospec=True,
        return_value=[get_observation1trans().data],
    )
    def test_get_model_predictions(self, _predict, _tr_data, _obs_from_data):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
            objective_name="a",
        )
        ax_client.get_next_trial()
        ax_client.experiment.trials[0].arm._name = "1_1"
        self.assertEqual(ax_client.get_model_predictions(), {0: {"a": (9.0, 1.0)}})

    def test_deprecated_save_load_method_errors(self):
        ax_client = AxClient()
        with self.assertRaises(NotImplementedError):
            ax_client.save()
        with self.assertRaises(NotImplementedError):
            ax_client.load()
        with self.assertRaises(NotImplementedError):
            ax_client.load_experiment("test_experiment")

    def test_find_last_trial_with_parameterization(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
            objective_name="a",
        )
        params, trial_idx = ax_client.get_next_trial()
        found_trial_idx = ax_client._find_last_trial_with_parameterization(
            parameterization=params
        )
        self.assertEqual(found_trial_idx, trial_idx)
        # Check that it's indeed the _last_ trial with params that is found.
        _, new_trial_idx = ax_client.attach_trial(parameters=params)
        found_trial_idx = ax_client._find_last_trial_with_parameterization(
            parameterization=params
        )
        self.assertEqual(found_trial_idx, new_trial_idx)
        with self.assertRaisesRegex(ValueError, "No .* matches"):
            found_trial_idx = ax_client._find_last_trial_with_parameterization(
                parameterization={k: v + 1.0 for k, v in params.items()}
            )

    def test_verify_parameterization(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
            objective_name="a",
        )
        params, trial_idx = ax_client.get_next_trial()
        self.assertTrue(
            ax_client.verify_trial_parameterization(
                trial_index=trial_idx, parameterization=params
            )
        )
        # Make sure it still works if the ordering in the parameterization differs.
        self.assertTrue(
            ax_client.verify_trial_parameterization(
                trial_index=trial_idx,
                parameterization={k: params[k] for k in reversed(list(params.keys()))},
            )
        )
        self.assertFalse(
            ax_client.verify_trial_parameterization(
                trial_index=trial_idx,
                parameterization={k: v + 1.0 for k, v in params.items()},
            )
        )