예제 #1
0
 def testBestArm(self):
     """A GeneratorRun should expose the best-arm prediction it was built with."""
     run = GeneratorRun(
         arms=self.arms,
         weights=self.weights,
         optimization_config=get_optimization_config(),
         search_space=get_search_space(),
         best_arm_predictions=(self.arms[0], ({"a": 1.0}, {"a": {"a": 2.0}})),
     )
     # Build an independent expected value rather than reusing the argument,
     # so the comparison is by value.
     expected = (self.arms[0], ({"a": 1.0}, {"a": {"a": 2.0}}))
     self.assertEqual(run.best_arm_predictions, expected)
예제 #2
0
 def testTrackingMetricsMerge(self):
     """Tracking metrics are merged with the optimization-config metrics.

     m1 is on the optimization config while m3 is not, so the experiment is
     expected to hold exactly one metric more than the optimization config.
     """
     tracking = [Metric(name="m1"), Metric(name="m3")]
     exp = Experiment(
         name="test2",
         search_space=get_search_space(),
         optimization_config=get_optimization_config(),
         tracking_metrics=tracking,
     )
     expected_total = len(exp.optimization_config.metrics) + 1
     self.assertEqual(expected_total, len(exp.metrics))
예제 #3
0
def get_generation_strategy(
    with_experiment: bool = False,
    with_callable_model_kwarg: bool = True,
) -> GenerationStrategy:
    """Build a generation-strategy fixture for tests.

    Args:
        with_experiment: If True, attach ``get_experiment()`` to the strategy.
        with_callable_model_kwarg: If True, store a callable under the first
            step's ``model_kwargs["model_constructor"]`` so that serialization
            of callable kwargs can be exercised.

    Returns:
        The strategy returned by ``choose_generation_strategy`` for the test
        search space, built with ``should_deduplicate=True``.
    """
    strategy = choose_generation_strategy(
        search_space=get_search_space(),
        should_deduplicate=True,
    )
    if with_experiment:
        strategy._experiment = get_experiment()
    if with_callable_model_kwarg:
        # pyre-ignore[16]: testing hack to test serialization of callable kwargs
        # in generation steps.
        strategy._steps[0].model_kwargs["model_constructor"] = get_experiment
    return strategy
예제 #4
0
 def testEmptyMetrics(self):
     """Fetching fails until a metric is tracked; then attached data is returned."""
     experiment = Experiment(
         name="test_experiment",
         search_space=get_search_space(),
     )
     self.assertEqual(experiment.num_trials, 0)
     # With no metric added yet, fetching is expected to raise.
     with self.assertRaises(ValueError):
         experiment.fetch_data()

     trial = experiment.new_batch_trial()
     self.assertEqual(experiment.num_trials, 1)
     with self.assertRaises(ValueError):
         trial.fetch_data()

     experiment.add_tracking_metric(Metric(name="some_metric"))
     experiment.attach_data(get_data())
     fetched = experiment.fetch_data()
     self.assertFalse(fetched.df.empty)
예제 #5
0
 def testModelPredictions(self):
     """Model predictions are exposed when supplied and None when omitted."""
     run = self.unweighted_run
     self.assertEqual(run.model_predictions, get_model_predictions())
     self.assertEqual(
         run.model_predictions_by_arm,
         get_model_predictions_per_arm(),
     )

     # A run built without model_predictions should report None for both views.
     bare_run = GeneratorRun(
         arms=self.arms,
         weights=self.weights,
         optimization_config=get_optimization_config(),
         search_space=get_search_space(),
     )
     for attr_name in ("model_predictions", "model_predictions_by_arm"):
         self.assertIsNone(getattr(bare_run, attr_name))
예제 #6
0
 def testEmptyMetrics(self):
     """Fetching on an experiment first without, then with, a tracked metric.

     Fetching raises ValueError before any metric is added. After adding a
     tracking metric the fetched frame is empty, and it becomes non-empty once
     data is attached and the running trial is marked completed.
     """
     experiment = Experiment(
         name="test_experiment",
         search_space=get_search_space(),
     )

     def fetched_df_is_empty():
         return experiment.fetch_data().df.empty

     self.assertEqual(experiment.num_trials, 0)
     with self.assertRaises(ValueError):
         experiment.fetch_data()

     trial = experiment.new_batch_trial()
     trial.mark_running(no_runner_required=True)
     self.assertEqual(experiment.num_trials, 1)
     with self.assertRaises(ValueError):
         trial.fetch_data()

     experiment.add_tracking_metric(Metric(name="ax_test_metric"))
     self.assertTrue(fetched_df_is_empty())

     experiment.attach_data(get_data())
     trial.mark_completed()
     self.assertFalse(fetched_df_is_empty())
예제 #7
0
파일: test_sqa_store.py 프로젝트: bitnot/Ax
    def testExperimentParameterUpdates(self):
        """Saving after search-space edits keeps SQAParameter rows in sync.

        Saves the experiment after an in-place update, an addition and a
        removal of a parameter, checking each time that the stored row count
        matches the in-memory parameter count. Finally checks that the
        experiment equals what load_experiment returns.
        """
        experiment = get_experiment_with_batch_trial()

        def save_and_check():
            # Persist, then compare stored rows with in-memory parameters.
            save_experiment(experiment)
            self.assertEqual(
                get_session().query(SQAParameter).count(),
                len(experiment.search_space.parameters),
            )

        save_and_check()

        # Update an existing parameter (should be updated in place).
        search_space = get_search_space()
        choice = get_choice_parameter()
        choice.add_values(["foobar"])
        search_space.update_parameter(choice)
        experiment.search_space = search_space
        save_and_check()

        # Add a new parameter.
        search_space.add_parameter(
            RangeParameter(
                name="x1",
                parameter_type=ParameterType.FLOAT,
                lower=-5,
                upper=10,
            )
        )
        experiment.search_space = search_space
        save_and_check()

        # Remove a parameter (its old row should be deleted).
        del search_space._parameters["x1"]
        experiment.search_space = search_space
        save_and_check()

        self.assertEqual(experiment, load_experiment(experiment.name))
예제 #8
0
def get_modelbridge(mock_gen_arms,
                    mock_observations_from_data,
                    status_quo_name: Optional[str] = None) -> ModelBridge:
    """Build a ModelBridge fixture whose ``_predict`` is stubbed out.

    Args:
        mock_gen_arms: Not used in this function; presumably injected by
            ``mock.patch`` decorators at the call site (TODO confirm). Kept
            for interface compatibility.
        mock_observations_from_data: Not used in this function; see
            ``mock_gen_arms``.
        status_quo_name: Optional status-quo arm name forwarded to the bridge.

    Returns:
        A ModelBridge over the test experiment, search space and data, backed
        by a FullFactorialGenerator. Its ``_predict`` is a MagicMock that
        returns ``[get_observation().data]`` on every call.
    """
    exp = get_experiment()
    modelbridge = ModelBridge(
        search_space=get_search_space(),
        model=FullFactorialGenerator(),
        experiment=exp,
        data=get_data(),
        status_quo_name=status_quo_name,
    )
    # MagicMock's first positional parameter is ``spec``, and ``autospec`` is a
    # ``mock.patch``/``create_autospec`` option, not a MagicMock one. Passing
    # the dotted path positionally would spec the mock against ``str``.
    # Use it as the mock's ``name`` so it shows up in reprs and failure output.
    modelbridge._predict = mock.MagicMock(
        name="ax.modelbridge.base.ModelBridge._predict",
        return_value=[get_observation().data],
    )
    return modelbridge
예제 #9
0
 def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data(self):
     """Relativize raises when its modelbridge has no status quo data."""
     sobol = Models.SOBOL(search_space=get_search_space())
     self.assertIsNone(sobol.status_quo)
     with self.assertRaisesRegex(ValueError, "status quo data"):
         # Construction stays inside the context: the error may come from
         # either the constructor or the transform call.
         relativize = Relativize(
             search_space=None,
             observation_features=[],
             observation_data=[],
             modelbridge=sobol,
         )
         data = [
             ObservationData(
                 metric_names=["foo"],
                 means=np.array([2]),
                 covariance=np.array([[0.1]]),
             )
         ]
         features = [ObservationFeatures(parameters={"x": 1})]
         relativize.transform_observation_data(
             observation_data=data,
             observation_features=features,
         )
예제 #10
0
    def setUp(self):
        """Create shared fixtures: predictions, config, search space, arms and two runs."""
        self.model_predictions = get_model_predictions()
        self.optimization_config = get_optimization_config()
        self.search_space = get_search_space()

        self.arms = get_arms()
        self.weights = [2, 1, 1]

        common_kwargs = dict(
            arms=self.arms,
            optimization_config=self.optimization_config,
            search_space=self.search_space,
            model_predictions=self.model_predictions,
        )
        # The unweighted run carries fit/gen timings; the weighted run carries
        # weights but no timings.
        self.unweighted_run = GeneratorRun(
            fit_time=4.0,
            gen_time=10.0,
            **common_kwargs,
        )
        self.weighted_run = GeneratorRun(weights=self.weights, **common_kwargs)
예제 #11
0
 def testBasicProperties(self):
     """The fixture experiment matches the standard test fixtures."""
     exp = self.experiment
     self.assertEqual(exp.status_quo, get_status_quo())
     self.assertEqual(exp.search_space, get_search_space())
     self.assertEqual(exp.optimization_config, get_optimization_config())
     self.assertEqual(exp.is_test, True)
예제 #12
0
def get_generation_strategy(
    with_experiment: bool = False,
) -> GenerationStrategy:
    """Return a test generation strategy, optionally with an experiment attached.

    Args:
        with_experiment: If True, set ``get_experiment()`` as the strategy's
            experiment before returning it.
    """
    strategy = choose_generation_strategy(search_space=get_search_space())
    if not with_experiment:
        return strategy
    strategy._experiment = get_experiment()
    return strategy
예제 #13
0
def get_generation_strategy() -> GenerationStrategy:
    """Return the generation strategy chosen for the test search space."""
    search_space = get_search_space()
    return choose_generation_strategy(search_space=search_space)