예제 #1
0
 def setUp(self) -> None:
     """Create a single-arm trial and a batch trial on two separate experiments.

     Both trials are generated from the same arm (``self.arm``).
     ``self.obs_feat`` holds the observation features built from the
     single-arm trial's arm and index.
     """
     self.experiment = get_experiment()
     self.arm = Arm({"x": 1, "y": "foo", "z": True, "w": 4})
     single_run = GeneratorRun([self.arm])
     self.trial = self.experiment.new_trial(single_run)

     self.experiment_2 = get_experiment()
     batch_run = GeneratorRun([self.arm])
     self.batch_trial = self.experiment_2.new_batch_trial(batch_run)

     trial_arm = self.trial.arm
     # The index is wrapped in a NumPy integer rather than passed as a plain
     # int. NOTE(review): presumably to exercise numpy-int handling — confirm.
     trial_index = np.int64(self.trial.index)
     self.obs_feat = ObservationFeatures.from_arm(
         arm=trial_arm, trial_index=trial_index
     )
예제 #2
0
 def setUp(self) -> None:
     """Create an experiment holding one trial with a single fixed arm."""
     self.experiment = get_experiment()
     arm_parameters = {"x": 1, "y": "foo", "z": True, "w": 4}
     generator_run = GeneratorRun([Arm(arm_parameters)])
     self.trial = self.experiment.new_trial(generator_run)
예제 #3
0
 def setUp(self):
     """Create a batch trial, with no status quo set, filled from the arm/weight fixtures.

     The first fixture arm and weight are kept aside as ``status_quo`` and
     ``sq_weight``. The remaining arms and weights are added to the batch.
     """
     self.experiment = get_experiment()
     self.experiment.status_quo = None
     self.batch = self.experiment.new_batch_trial()

     fixture_arms, fixture_weights = get_arms(), get_weights()
     self.status_quo, self.sq_weight = fixture_arms[0], fixture_weights[0]
     self.arms, self.weights = fixture_arms[1:], fixture_weights[1:]
     self.batch.add_arms_and_weights(arms=self.arms, weights=self.weights)
예제 #4
0
    def test_save_load_experiment(self):
        """Round-trip experiments through storage configured with a ``sqlite://`` URL.

        An experiment from ``get_experiment()`` must save and load without
        error. Loading a saved ``get_simple_experiment()`` experiment must
        raise ``ValueError`` whose message matches "Service API only".
        """
        experiment = get_experiment()
        init_test_engine_and_session_factory()
        settings = DBSettings(url="sqlite://")

        save_experiment(experiment, settings)
        load_experiment(experiment.name, settings)

        simple = get_simple_experiment()
        save_experiment(simple, settings)
        with self.assertRaisesRegex(ValueError, "Service API only"):
            load_experiment(simple.name, settings)
예제 #5
0
    def testBasicBatchCreation(self):
        """A new batch trial is attached to its experiment and cannot be re-attached.

        Checks three things:
        - the batch is the experiment's only trial;
        - ``_attach_trial`` raises ``ValueError`` when the batch is attached
          again to the same experiment;
        - ``_attach_trial`` also raises ``ValueError`` when the batch is
          attached to a different experiment.
        """
        batch = self.experiment.new_batch_trial()
        self.assertEqual(len(self.experiment.trials), 1)
        self.assertEqual(self.experiment.trials[0], batch)

        # Try (and fail) to re-attach batch
        with self.assertRaises(ValueError):
            self.experiment._attach_trial(batch)

        # Try (and fail) to attach batch to another experiment. The other
        # experiment is built outside the assertRaises block, so a ValueError
        # from fixture construction cannot make this check pass spuriously.
        new_exp = get_experiment()
        with self.assertRaises(ValueError):
            new_exp._attach_trial(batch)
예제 #6
0
    def test_save_load_experiment(self):
        """Round-trip experiments through storage using explicit SQA encoder/decoder settings.

        The test engine/session factory is (re)initialised with
        ``force_init=True``. An experiment from ``get_experiment()`` must save
        and load without error. Loading a saved ``get_simple_experiment()``
        experiment must raise ``ValueError`` whose message matches
        "Service API only".
        """
        experiment = get_experiment()
        init_test_engine_and_session_factory(force_init=True)

        # The encoder and decoder each get their own SQAConfig instance.
        encoder = Encoder(config=SQAConfig())
        decoder = Decoder(config=SQAConfig())
        settings = DBSettings(encoder=encoder, decoder=decoder, creator=None)

        save_experiment(experiment, settings)
        load_experiment(experiment.name, settings)

        simple = get_simple_experiment()
        save_experiment(simple, settings)
        with self.assertRaisesRegex(ValueError, "Service API only"):
            load_experiment(simple.name, settings)
예제 #7
0
def get_modelbridge(mock_gen_arms,
                    mock_observations_from_data,
                    status_quo_name: Optional[str] = None) -> ModelBridge:
    """Build a ``ModelBridge`` over the test experiment with a stubbed ``_predict``.

    Args:
        mock_gen_arms: Not used in this body. NOTE(review): presumably a mock
            injected by a ``mock.patch`` decorator outside this view; confirm.
        mock_observations_from_data: Not used in this body; same assumption
            as ``mock_gen_arms``.
        status_quo_name: Optional status-quo arm name forwarded to
            ``ModelBridge``.

    Returns:
        A ``ModelBridge`` built on ``get_search_space()``, a
        ``FullFactorialGenerator``, ``get_experiment()`` and ``get_data()``.
        Its ``_predict`` is replaced by a mock that always returns
        ``[get_observation().data]``.
    """
    exp = get_experiment()
    modelbridge = ModelBridge(
        search_space=get_search_space(),
        model=FullFactorialGenerator(),
        experiment=exp,
        data=get_data(),
        status_quo_name=status_quo_name,
    )
    # ``MagicMock``'s first positional parameter is ``spec`` and it has no
    # ``autospec`` option. Passing a patch-target string plus autospec=True
    # therefore only configured a ``str`` spec and set a stray attribute.
    # A plain mock with a fixed return value is all this stub needs.
    modelbridge._predict = mock.MagicMock(
        return_value=[get_observation().data],
    )
    return modelbridge
예제 #8
0
 def setUp(self):
     """Store a freshly built test experiment on the instance."""
     experiment = get_experiment()
     self.experiment = experiment
예제 #9
0
 def setUp(self):
     """Create an experiment whose single trial holds the first fixture arm."""
     self.experiment = get_experiment()
     self.trial = self.experiment.new_trial()
     fixture_arms = get_arms()
     self.arm = fixture_arms[0]
     self.trial.add_arm(self.arm)
예제 #10
0
 def setUp(self):
     """Save a test experiment using DB settings with a ``sqlite://`` URL.

     Sets ``self.exp`` and ``self.db_settings`` for use by the tests.
     """
     self.exp = get_experiment()
     # The helper's keyword is ``force_init``; ``force=`` does not match it.
     # NOTE(review): confirm against the helper's signature in this Ax version.
     init_test_engine_and_session_factory(force_init=True)
     self.db_settings = DBSettings(url="sqlite://")
     save_experiment(self.exp, self.db_settings)