# Example 1
# Score: 0
 def test_attach_trial_and_get_trial_parameters(self):
     """Check that attached trial parameters round-trip and are validated.

     A trial is attached with explicit float values and completed. It must
     then be reported as the best parameterization, and
     ``get_trial_parameters`` must return the attached values. Asking for a
     trial index that does not exist, or attaching integer values
     ({"x": 1, "y": 2}), must raise ``ValueError``.
     """
     client = AxClient()
     search_space = [
         {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
         {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
     ]
     client.create_experiment(parameters=search_space, minimize=True)

     attached, trial_index = client.attach_trial(
         parameters={"x": 0.0, "y": 1.0}
     )
     client.complete_trial(trial_index=trial_index, raw_data=5)

     # The only completed trial must come back as the best one.
     best_params = client.get_best_parameters()[0]
     self.assertEqual(best_params, attached)

     retrieved = client.get_trial_parameters(trial_index=trial_index)
     expected = {"x": 0, "y": 1}
     self.assertEqual(retrieved, expected)

     # Only one trial has been attached, so index 10 must be rejected.
     with self.assertRaises(ValueError):
         client.get_trial_parameters(trial_index=10)

     # Integer values are supplied here, unlike the float values above.
     with self.assertRaisesRegex(ValueError, ".* is of type"):
         client.attach_trial({"x": 1, "y": 2})
# Example 2
# Score: 0
    def test_attach_trial_ttl_seconds(self):
        """Check that a trial attached with a TTL is marked FAILED after it elapses.

        Once expired, the trial must refuse data. A fresh trial with the same
        parameters and TTL, completed right away, must succeed and be reported
        as the best parameterization.
        """
        ttl_seconds = 1
        # Sleep slightly past the TTL instead of exactly the TTL. The expiry
        # check is not guaranteed to use the same clock (or clock resolution)
        # that ``time.sleep`` measures. Sleeping precisely ``ttl_seconds``
        # therefore races it and makes the ``is_failed`` assertion flaky.
        ttl_margin_seconds = 0.1

        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(
            parameters={"x": 0.0, "y": 1.0}, ttl_seconds=ttl_seconds
        )
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
        time.sleep(ttl_seconds + ttl_margin_seconds)  # Wait for TTL to elapse.
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        # Also make sure we can no longer complete the trial as it is failed.
        with self.assertRaisesRegex(
            ValueError, ".* has been marked FAILED, so it no longer expects data."
        ):
            ax_client.complete_trial(trial_index=idx, raw_data=5)

        # A second trial with the same TTL, completed before it expires,
        # behaves like a normal trial.
        params2, idx2 = ax_client.attach_trial(
            parameters={"x": 0.0, "y": 1.0}, ttl_seconds=ttl_seconds
        )
        ax_client.complete_trial(trial_index=idx2, raw_data=5)
        self.assertEqual(ax_client.get_best_parameters()[0], params2)
        self.assertEqual(
            ax_client.get_trial_parameters(trial_index=idx2), {"x": 0, "y": 1}
        )
# Example 3
# Score: 0
import pickle
# Experiment variant whose saved hyperparameters are loaded from disk.
run_mode = "frozen_convolution_no_center_relu"
snapshot_path = f"hyperparameters_{run_mode}.pl"
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load snapshots produced by a trusted run.
with open(snapshot_path, "rb") as handle:
    hyper = pickle.load(handle)

from ax import RangeParameter, ParameterType
from ax.service.ax_client import AxClient
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate

# Restore the client from the JSON snapshot stored under the "axclient" key.
# ``from_json_snapshot`` returns a new client, so call it on the class
# directly. Building a throwaway ``AxClient()`` first is unnecessary, since
# that instance would be discarded immediately.
ax = AxClient.from_json_snapshot(hyper["axclient"])
# Print the parameters attached to trial index 10. This raises ValueError
# if the restored experiment has no such trial.
print(ax.get_trial_parameters(trial_index=10))