def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None:
    """Test that Sobol+GPEI is used if no GenerationStrategy is provided.

    Creates a two-parameter continuous experiment without an explicit
    GenerationStrategy and checks that:
      * the default strategy's steps are [SOBOL, GPEI];
      * an optimization trace cannot be requested before any trials exist;
      * after six completed Branin trials the current model is GPEI;
      * the trace / contour / feature-importance helpers then run;
      * the trials data frame has x, y, a columns and six rows.
    """
    client = AxClient()
    client.create_experiment(
        parameters=[
            # pyre-fixme[6]: expected union that should include
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        objective_name="a",
        minimize=True,
    )
    step_models = [
        step.model for step in not_none(client.generation_strategy)._steps
    ]
    self.assertEqual(step_models, [Models.SOBOL, Models.GPEI])

    # No trials have been run yet, so there is nothing to build a trace from.
    with self.assertRaisesRegex(ValueError, ".* no trials"):
        client.get_optimization_trace(objective_optimum=branin.fmin)

    for sample_size in range(6):
        params, trial_index = client.get_next_trial()
        x_value = checked_cast(float, params.get("x"))
        y_value = checked_cast(float, params.get("y"))
        objective_value = checked_cast(float, branin(x_value, y_value))
        # Report the Branin value as metric "a" with a 0.0 second entry
        # (presumably the SEM — confirm against AxClient.complete_trial).
        client.complete_trial(
            trial_index,
            raw_data={"a": (objective_value, 0.0)},
            sample_size=sample_size,
        )

    # After six completed trials the strategy is expected to be on GPEI.
    self.assertEqual(client.generation_strategy.model._model_key, "GPEI")

    # With data available, the reporting helpers should now succeed.
    client.get_optimization_trace(objective_optimum=branin.fmin)
    client.get_contour_plot()
    client.get_feature_importances()

    trials_df = client.get_trials_data_frame()
    for column in ("x", "y", "a"):
        self.assertIn(column, trials_df)
    self.assertEqual(len(trials_df), 6)
def test_plotting_validation(self):
    """Check the input validation of AxClient's plotting helpers.

    First uses an experiment with a single fixed parameter to hit the
    "no trials" and "less than 2 parameters" errors. Then uses a
    two-parameter range experiment to hit each get_contour_plot argument
    error and the feature-importance failure.
    """
    # Single fixed parameter, no trials yet.
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x3", "type": "fixed", "value": 2, "value_type": "int"}
        ]
    )
    for plot_helper in (client.get_contour_plot, client.get_feature_importances):
        with self.assertRaisesRegex(ValueError, ".* there are no trials"):
            plot_helper()

    # Once a trial exists, the contour plot still needs two parameters.
    client.get_next_trial()
    with self.assertRaisesRegex(ValueError, ".* less than 2 parameters"):
        client.get_contour_plot()

    # Two range parameters and a single generated trial.
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ]
    )
    client.get_next_trial()

    # (expected exception type, message regex, get_contour_plot kwargs),
    # checked in this order.
    contour_cases = [
        (ValueError, "If `param_x` is provided", {"param_x": "y"}),
        (ValueError, "If `param_x` is provided", {"param_y": "y"}),
        (ValueError, 'Parameter "x3"', {"param_x": "x3", "param_y": "x3"}),
        (ValueError, 'Parameter "x4"', {"param_x": "x", "param_y": "x4"}),
        (
            ValueError,
            'Metric "nonexistent"',
            {"param_x": "x", "param_y": "y", "metric_name": "nonexistent"},
        ),
        (
            UnsupportedPlotError,
            "Could not obtain contour",
            {"param_x": "x", "param_y": "y", "metric_name": "objective"},
        ),
    ]
    for expected_error, pattern, plot_kwargs in contour_cases:
        with self.assertRaisesRegex(expected_error, pattern):
            client.get_contour_plot(**plot_kwargs)

    with self.assertRaisesRegex(ValueError, "Could not obtain feature"):
        client.get_feature_importances()