def test_raw_data_format(self):
    """Reject raw data of a type that `update_trial_data` cannot interpret.

    Six trials are generated and each is completed with a
    ``(branin(x, y), 0.0)`` tuple. Updating the last of those trials with a
    plain string must then raise a ValueError whose message mentions the
    invalid raw-data type.
    """
    search_space = [
        {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
    ]
    client = AxClient()
    client.create_experiment(parameters=search_space, minimize=True)

    last_index = None
    for _ in range(6):
        arm_params, last_index = client.get_next_trial()
        x_val = arm_params.get("x")
        y_val = arm_params.get("y")
        client.complete_trial(last_index, raw_data=(branin(x_val, y_val), 0.0))

    # A plain string is not a supported raw-data form, so the client is
    # expected to refuse it even for an already-completed trial.
    with self.assertRaisesRegex(ValueError, "Raw data has an invalid type"):
        client.update_trial_data(last_index, raw_data="invalid_data")
def test_trial_completion(self):
    """Exercise trial completion, data updates and best-point lookup.

    Checks that:
      * `update_trial_data` is refused before a trial has been completed;
      * `complete_trial` is refused for an already-completed trial;
      * `update_trial_data` is refused when it would add a second
        observation for a metric the trial already has, whether the metric
        is named explicitly or inferred from a bare value;
      * data for a new metric ("m1") can be attached via `update_trial_data`;
      * `get_best_parameters` follows the lowest objective (the experiment
        minimizes), and `complete_trial` records the supplied metadata in
        the trial's `run_metadata`.
    """
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    params, idx = ax_client.get_next_trial()

    # Can't update before completing.
    with self.assertRaisesRegex(ValueError, ".* not yet"):
        ax_client.update_trial_data(
            trial_index=idx, raw_data={"objective": (0, 0.0)}
        )
    ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})

    # Cannot complete a trial twice, should use `update_trial_data`.
    with self.assertRaisesRegex(ValueError, ".* already been completed"):
        ax_client.complete_trial(
            trial_index=idx, raw_data={"objective": (0, 0.0)}
        )

    # Cannot update trial data with an observation for a metric it already has.
    with self.assertRaisesRegex(ValueError, ".* contained an observation"):
        ax_client.update_trial_data(
            trial_index=idx, raw_data={"objective": (0, 0.0)}
        )
    # Same as above, except the objective name should be inferred from the
    # bare value.
    with self.assertRaisesRegex(ValueError, ".* contained an observation"):
        ax_client.update_trial_data(trial_index=idx, raw_data=1.0)

    # Data for a metric the trial does not have yet is accepted and appears
    # alongside the existing objective observation.
    ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)})
    metrics_in_data = ax_client.experiment.fetch_data().df["metric_name"].values
    self.assertIn("m1", metrics_in_data)
    self.assertIn("objective", metrics_in_data)
    self.assertEqual(ax_client.get_best_parameters()[0], params)

    # A lower objective value becomes the new best, since we minimize.
    params2, idx2 = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=idx2, raw_data=(-1, 0.0))
    self.assertEqual(ax_client.get_best_parameters()[0], params2)

    params3, idx3 = ax_client.get_next_trial()
    ax_client.complete_trial(
        trial_index=idx3, raw_data=-2, metadata={"dummy": "test"}
    )
    # Query the best point once and check both parts of the result.
    best_point = ax_client.get_best_parameters()
    self.assertEqual(best_point[0], params3)
    # Look the trial up by the index we were given rather than assuming it
    # is trial 2.
    self.assertEqual(
        ax_client.experiment.trials.get(idx3).run_metadata.get("dummy"), "test"
    )
    best_trial_values = best_point[1]
    self.assertEqual(best_trial_values[0], {"objective": -2.0})
    # The bare value -2 was completed without an accompanying second element,
    # presumably why its (co)variance entry is NaN.
    self.assertTrue(math.isnan(best_trial_values[1]["objective"]["objective"]))