Example #1
0
 def test_raw_data_format(self):
     """Reject ``raw_data`` whose type AxClient cannot interpret.

     Six trials are generated and each is completed with a Branin
     observation given as a ``(mean, sem)`` tuple.  Updating the last of
     those trials with a plain string must raise a ``ValueError`` that
     mentions an invalid raw-data type.
     """
     search_space = [
         {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
         {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
     ]
     client = AxClient()
     client.create_experiment(parameters=search_space, minimize=True)
     num_trials = 6
     for _ in range(num_trials):
         point, last_index = client.get_next_trial()
         x_val = point.get("x")
         y_val = point.get("y")
         mean = branin(x_val, y_val)
         client.complete_trial(last_index, raw_data=(mean, 0.0))
     # `last_index` is the index of the final trial produced by the loop.
     expected_msg = "Raw data has an invalid type"
     with self.assertRaisesRegex(ValueError, expected_msg):
         client.update_trial_data(last_index, raw_data="invalid_data")
Example #2
0
 def test_trial_completion(self):
     """Exercise the completion/update lifecycle of AxClient trials.

     Checks that:
       * data cannot be updated before a trial has been completed;
       * a trial cannot be completed twice;
       * ``update_trial_data`` rejects an observation for a metric the trial
         already has, whether the metric is named explicitly or inferred as
         the objective;
       * a new metric can be attached to an already-completed trial;
       * the best parameters follow the lowest objective (``minimize=True``);
       * ``metadata`` passed to ``complete_trial`` ends up in the trial's
         ``run_metadata``.
     """
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     params, idx = ax_client.get_next_trial()
     # Can't update before completing.
     with self.assertRaisesRegex(ValueError, ".* not yet"):
         ax_client.update_trial_data(
             trial_index=idx, raw_data={"objective": (0, 0.0)}
         )
     ax_client.complete_trial(
         trial_index=idx, raw_data={"objective": (0, 0.0)}
     )
     # Cannot complete a trial twice; `update_trial_data` should be used.
     with self.assertRaisesRegex(ValueError, ".* already been completed"):
         ax_client.complete_trial(
             trial_index=idx, raw_data={"objective": (0, 0.0)}
         )
     # Cannot update trial data with an observation for a metric it already
     # has.
     with self.assertRaisesRegex(ValueError, ".* contained an observation"):
         ax_client.update_trial_data(
             trial_index=idx, raw_data={"objective": (0, 0.0)}
         )
     # Same as above, except the objective name should be inferred.
     with self.assertRaisesRegex(ValueError, ".* contained an observation"):
         ax_client.update_trial_data(trial_index=idx, raw_data=1.0)
     # A metric the trial has no data for yet can be added after completion.
     ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)})
     metrics_in_data = (
         ax_client.experiment.fetch_data().df["metric_name"].values
     )
     self.assertIn("m1", metrics_in_data)
     self.assertIn("objective", metrics_in_data)
     self.assertEqual(ax_client.get_best_parameters()[0], params)
     # The experiment minimizes, so a lower objective becomes the new best.
     params2, idx2 = ax_client.get_next_trial()
     ax_client.complete_trial(trial_index=idx2, raw_data=(-1, 0.0))
     self.assertEqual(ax_client.get_best_parameters()[0], params2)
     # A bare number is accepted as raw data, together with run metadata.
     params3, idx3 = ax_client.get_next_trial()
     ax_client.complete_trial(
         trial_index=idx3, raw_data=-2, metadata={"dummy": "test"}
     )
     # Nothing changes the experiment between these checks, so fetch the best
     # parameters/values once instead of recomputing them.
     best_params, best_trial_values = ax_client.get_best_parameters()
     self.assertEqual(best_params, params3)
     # Look the trial up by the index Ax returned rather than a hard-coded
     # position, so the check survives changes to the trials created above.
     self.assertEqual(
         ax_client.experiment.trials[idx3].run_metadata.get("dummy"), "test"
     )
     self.assertEqual(best_trial_values[0], {"objective": -2.0})
     # No SEM was given for the bare-number observation, so the reported
     # covariance entry is expected to be NaN.
     self.assertTrue(
         math.isnan(best_trial_values[1]["objective"]["objective"])
     )