Beispiel #1
0
    def testExperimentGeneratorRunUpdates(self):
        """Check persisted generator-run counts as a batch trial is modified.

        The experiment is saved after each change. The number of stored
        ``SQAGeneratorRun`` rows is asserted after every save, and the final
        experiment must compare equal to its reloaded copy.
        """

        def count_generator_runs():
            # Query fresh each time so every assertion sees the latest save.
            return get_session().query(SQAGeneratorRun).count()

        experiment = get_experiment_with_batch_trial()
        save_experiment(experiment)
        # Expected: one main generator run plus one for the status quo.
        self.assertEqual(count_generator_runs(), 2)

        # Adding an arm is expected to create one wrapper generator run and to
        # replace the status quo generator run (the status quo's weight
        # changes), so the net total grows by one.
        batch = experiment.trials[0]
        batch.add_arm(get_arm())
        save_experiment(experiment)
        self.assertEqual(count_generator_runs(), 3)

        extra_run = get_generator_run()  # TODO[Lena, T46190605]: remove
        # Clear model metadata on the extra run before attaching it.
        extra_run._model_key = None
        extra_run._model_kwargs = None
        extra_run._bridge_kwargs = None
        batch.add_generator_run(generator_run=extra_run, multiplier=0.5)
        save_experiment(experiment)
        self.assertEqual(count_generator_runs(), 4)

        # Round-trip: the stored experiment must load back equal to the original.
        reloaded = load_experiment(experiment.name)
        self.assertEqual(experiment, reloaded)
Beispiel #2
0
 def testNumArmsNoDeduplication(self):
     """Check that one arm shared by two batch trials is counted per trial.

     ``sum_trial_sizes`` is asserted to count the arm in both trials.
     ``arms_by_name`` is asserted to hold a single entry. After the arm is
     abandoned in the second trial, ``num_abandoned_arms`` is asserted to be 1.
     """
     experiment = Experiment(
         name="test_experiment", search_space=get_search_space()
     )
     shared_arm = get_arm()

     # Attach the very same arm object to two separate batch trials.
     experiment.new_batch_trial().add_arm(shared_arm)
     second_trial = experiment.new_batch_trial().add_arm(shared_arm)

     self.assertEqual(experiment.sum_trial_sizes, 2)
     self.assertEqual(len(experiment.arms_by_name), 1)

     # Abandon the arm in the second trial only.
     second_trial.mark_arm_abandoned(second_trial.arms[0].name)
     self.assertEqual(experiment.num_abandoned_arms, 1)
Beispiel #3
0
    def testEq(self):
        """Check experiment equality.

        The fixture experiment must equal itself, and a separately constructed
        experiment named "test2" must not equal it.
        """
        # Reflexive case: the fixture compared against itself.
        self.assertEqual(self.experiment, self.experiment)

        other_experiment = Experiment(
            name="test2",
            search_space=get_search_space(),
            optimization_config=get_optimization_config(),
            status_quo=get_arm(),
            description="test description",
        )
        self.assertNotEqual(self.experiment, other_experiment)
Beispiel #4
0
    def testAddArm(self):
        """Check the batch state before and after ``add_arm`` with weight 3.

        Before the call, the batch is expected to match the fixture arms and
        weights in a single generator run struct. Afterwards it should have
        one more arm, one more struct, and 3 more units of total weight.
        """
        # Initial state mirrors the fixture arms and weights.
        self.assertEqual(len(self.batch.arms), len(self.arms))
        self.assertEqual(len(self.batch.generator_run_structs), 1)
        self.assertEqual(sum(self.batch.weights), sum(self.weights))

        # Take a fresh template arm's parameters and override "w" so the new
        # arm differs from the existing ones.
        new_arm_params = get_arm().parameters
        new_arm_params["w"] = 5.0
        self.batch.add_arm(Arm(new_arm_params), 3)

        # The new arm is expected to be named "0_2" and to sit at index 2.
        self.assertEqual(self.batch.arms_by_name["0_2"], self.batch.arms[2])
        self.assertEqual(len(self.batch.arms), len(self.arms) + 1)
        self.assertEqual(len(self.batch.generator_run_structs), 2)
        self.assertEqual(sum(self.batch.weights), sum(self.weights) + 3)