def testExperimentUpdateTrial(self):
    """update_trial persists trial status, run metadata and attached data.

    Each change is pushed with ``update_trial`` and the experiment is
    reloaded from the DB to confirm the change round-trips.
    """
    experiment = self.experiment
    save_experiment(experiment)

    def push_and_reload(updated_trial):
        # Write the trial's current state, then read the experiment back.
        update_trial(experiment=experiment, trial=updated_trial)
        return load_experiment(experiment.name)

    # Existing trial: first a status change, then a run-metadata change.
    existing = experiment.trials[0]
    existing.mark_staged()
    self.assertEqual(existing, push_and_reload(existing).trials[0])

    existing._run_metadata = {"foo": "bar"}
    self.assertEqual(existing, push_and_reload(existing).trials[0])

    # Data attached to the existing trial must survive the update too.
    experiment.attach_data(get_data(trial_index=existing.index))
    self.assertEqual(experiment, push_and_reload(existing))

    # A brand-new batch trial: saved via save_new_trial, then given data.
    fresh = experiment.new_batch_trial(generator_run=get_generator_run())
    save_new_trial(experiment=experiment, trial=fresh)
    experiment.attach_data(get_data(trial_index=fresh.index))
    self.assertEqual(experiment, push_and_reload(fresh))
def testUpdateGenerationStrategyIncrementally(self):
    """Persist several generator runs with one incremental update.

    Generates seven generator runs from a strategy chosen for the Branin
    search space. The first run gets no data; the rest get Branin data.
    Each run becomes a completed trial, and all of them are stored with a
    single ``update_generation_strategy`` call. The strategy reloaded by
    experiment name must sit on the same step and hold all seven runs.
    """
    experiment = get_branin_experiment()
    generation_strategy = choose_generation_strategy(experiment.search_space)
    save_experiment(experiment=experiment)
    save_generation_strategy(generation_strategy=generation_strategy)

    # add generator runs, save, reload
    generator_runs = []
    for i in range(7):
        data = get_branin_data() if i > 0 else None
        gr = generation_strategy.gen(experiment, data=data)
        generator_runs.append(gr)
        trial = experiment.new_trial(generator_run=gr).mark_running(
            no_runner_required=True
        )
        trial.mark_completed()

    save_experiment(experiment=experiment)
    update_generation_strategy(
        generation_strategy=generation_strategy,
        generator_runs=generator_runs,
    )
    loaded_generation_strategy = load_generation_strategy_by_experiment_name(
        experiment_name=experiment.name
    )
    # Compare only the two step indices. Note that assertEqual's third
    # positional argument is the failure message, not a third operand.
    self.assertEqual(
        generation_strategy._curr.index, loaded_generation_strategy._curr.index
    )
    self.assertEqual(len(loaded_generation_strategy._generator_runs), 7)
def testSaveValidation(self):
    """save_experiment raises ValueError for a non-experiment or a nameless one."""
    # A trial object is not an acceptable argument.
    not_an_experiment = self.experiment.trials[0]
    with self.assertRaises(ValueError):
        save_experiment(not_an_experiment)

    # An experiment whose name was cleared cannot be saved either.
    unnamed = get_experiment_with_batch_trial()
    unnamed.name = None
    with self.assertRaises(ValueError):
        save_experiment(unnamed)
def testCopyDBIDsRepeatedArms(self):
    """Re-using saved arms in a new batch trial must not share their db_ids.

    Saves an experiment, then adds a second batch trial containing the same
    arms as the first and saves again. The first arm of each trial must end
    up with a distinct ``db_id``.
    """
    exp = get_experiment_with_batch_trial()
    save_experiment(exp)

    exp.new_batch_trial().add_arms_and_weights(exp.trials[0].arms)
    save_experiment(exp)

    self.assertNotEqual(
        exp.trials[0].arms[0].db_id, exp.trials[1].arms[0].db_id
    )
def testExperimentSaveAndLoad(self):
    """Single-, multi- and scalarized-objective experiments round-trip intact."""
    candidates = (
        self.experiment,
        get_experiment_with_multi_objective(),
        get_experiment_with_scalarized_objective(),
    )
    for original in candidates:
        save_experiment(original)
        self.assertEqual(load_experiment(original.name), original)
def testMTExperimentSaveAndLoad(self):
    """A multi-type experiment keeps its trial-type bookkeeping after reload."""
    mt_experiment = get_multi_type_experiment(add_trials=True)
    save_experiment(mt_experiment)
    reloaded = load_experiment(mt_experiment.name)

    self.assertEqual(reloaded.default_trial_type, "type1")
    self.assertEqual(len(reloaded._trial_type_to_runner), 2)

    # Each metric maps to its own trial type; m2 is canonicalized to m1.
    expected_trial_types = {"m1": "type1", "m2": "type2"}
    for metric_name, trial_type in expected_trial_types.items():
        self.assertEqual(reloaded.metric_to_trial_type[metric_name], trial_type)
    self.assertEqual(reloaded._metric_to_canonical_name["m2"], "m1")

    self.assertEqual(len(reloaded.trials), 2)
def testExperimentNewTrialValidation(self):
    """save_new_trial rejects unsaved experiments and repeated saves."""
    candidate = self.experiment.new_batch_trial()

    # The experiment itself has not been persisted yet.
    with self.assertRaises(ValueError):
        save_new_trial(experiment=self.experiment, trial=candidate)

    save_experiment(self.experiment)

    # A trial cannot be saved as "new" a second time.
    with self.assertRaises(ValueError):
        save_new_trial(experiment=self.experiment, trial=candidate)
def testCopyDBIDsBatchTrialExp(self):
    """copy_db_ids restores db_ids stripped from a trial and one of its arms."""
    source = get_experiment_with_batch_trial()
    save_experiment(source)
    target = load_experiment(source.name)
    self.assertEqual(source, target)

    # Knock out some db_ids on the reloaded copy...
    first_trial = target.trials[0]
    first_trial.db_id = None
    first_trial.generator_runs[0].arms[0].db_id = None

    # ...and recover them from the original.
    copy_db_ids(source, target)
    self.assertEqual(source, target)
def testExperimentUpdateTrialValidation(self):
    """update_trial rejects unsaved experiments and trials with a bad index."""
    existing = self.experiment.trials[0]

    # The experiment must be saved before any trial can be updated.
    with self.assertRaises(ValueError):
        update_trial(experiment=self.experiment, trial=existing)

    save_experiment(self.experiment)

    # Re-point the trial at index 1, which the test treats as a bad index.
    existing._index = 1
    with self.assertRaises(ValueError):
        update_trial(experiment=self.experiment, trial=existing)
def testCopyDBIDsDataExp(self):
    """copy_db_ids restores the db_id cleared on a trial's attached data."""
    saved = get_experiment_with_data()
    save_experiment(saved)
    reloaded = load_experiment(saved.name)
    self.assertEqual(saved, reloaded)

    # Clear the db_id of the data looked up for trial 0 on the copy.
    trial_data, _unused = reloaded.lookup_data_for_trial(0)
    trial_data.db_id = None

    copy_db_ids(saved, reloaded)
    self.assertEqual(saved, reloaded)
def testExperimentAbandonedArmUpdates(self):
    """Abandoning another arm adds one SQAAbandonedArm row on save."""

    def abandoned_arm_rows():
        return get_session().query(SQAAbandonedArm).count()

    experiment = get_experiment_with_batch_trial()
    save_experiment(experiment)
    # The fixture already contains one abandoned arm.
    self.assertEqual(abandoned_arm_rows(), 1)

    batch = experiment.trials[0]
    batch.mark_arm_abandoned(batch.arms[1].name)
    save_experiment(experiment)
    self.assertEqual(abandoned_arm_rows(), 2)

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentTrackingMetricUpdates(self):
    """SQAMetric row count tracks tracking-metric updates, adds and removals."""
    experiment = get_experiment_with_batch_trial()

    def save_and_check_metric_rows():
        save_experiment(experiment)
        self.assertEqual(
            get_session().query(SQAMetric).count(), len(experiment.metrics)
        )

    save_and_check_metric_rows()

    # Update an existing tracking metric (should perform update in place).
    experiment.update_tracking_metric(
        Metric(name="tracking", lower_is_better=True)
    )
    save_and_check_metric_rows()

    # Add a second tracking metric.
    experiment.add_tracking_metric(Metric(name="tracking2"))
    save_and_check_metric_rows()

    # Remove it again (its row should be deleted).
    experiment.remove_tracking_metric("tracking2")
    save_and_check_metric_rows()

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentRunnerUpdates(self):
    """SQARunner rows follow adding, replacing and removing the experiment runner."""

    def runner_rows():
        return get_session().query(SQARunner).count()

    experiment = get_experiment_with_batch_trial()
    save_experiment(experiment)
    # So far only the batch trial carries a runner.
    self.assertEqual(runner_rows(), 1)

    # Attach a runner to the experiment itself.
    experiment.runner = get_synthetic_runner()
    save_experiment(experiment)
    self.assertEqual(runner_rows(), 2)

    # Swap in a modified runner (should perform update in place).
    replacement = get_synthetic_runner()
    replacement.dummy_metadata = {"foo": "bar"}
    experiment.runner = replacement
    save_experiment(experiment)
    self.assertEqual(runner_rows(), 2)

    # Drop the experiment runner (its row should be deleted).
    experiment.runner = None
    save_experiment(experiment)
    self.assertEqual(runner_rows(), 1)

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentTrialUpdates(self):
    """Adding a trial with a runner, editing the runner and running persist cleanly."""

    def rows(model):
        return get_session().query(model).count()

    experiment = get_experiment_with_batch_trial()
    save_experiment(experiment)
    self.assertEqual(rows(SQATrial), 1)
    self.assertEqual(rows(SQARunner), 1)

    # A second batch trial with its own runner.
    new_trial = experiment.new_batch_trial()
    trial_runner = get_synthetic_runner()
    new_trial.runner = trial_runner
    save_experiment(experiment)
    self.assertEqual(rows(SQATrial), 2)
    self.assertEqual(rows(SQARunner), 2)

    # Mutating and re-assigning the runner must not add rows.
    trial_runner.dummy_metadata = "dummy metadata"
    new_trial.runner = trial_runner
    save_experiment(experiment)
    self.assertEqual(rows(SQATrial), 2)
    self.assertEqual(rows(SQARunner), 2)

    new_trial.run()
    save_experiment(experiment)
    self.assertEqual(rows(SQATrial), 2)

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentNewTrial(self):
    """Trials saved via save_new_trial show up at the expected index on reload."""
    save_experiment(self.experiment)

    def add_and_verify(trial, expected_count):
        save_new_trial(experiment=self.experiment, trial=trial)
        reloaded = load_experiment(self.experiment.name)
        self.assertEqual(len(reloaded.trials), expected_count)
        self.assertEqual(trial, reloaded.trials[expected_count - 1])

    # An empty batch trial, then one built from a generator run.
    add_and_verify(self.experiment.new_batch_trial(), 2)
    add_and_verify(
        self.experiment.new_batch_trial(generator_run=get_generator_run()), 3
    )
def testUpdateGenerationStrategy(self):
    """Generation strategies round-trip after data, gen calls and experiment edits.

    NOTE(review): two more methods named ``testUpdateGenerationStrategy``
    appear later in this file. If they share a class, this definition is
    shadowed and never runs.
    """
    # Part 1: a standalone strategy gets data, is saved, and is reloaded by id.
    standalone_gs = get_generation_strategy()
    save_generation_strategy(generation_strategy=standalone_gs)
    records = [{"metric_name": "foo", "mean": 1, "arm_name": "bar"}]
    standalone_gs._data = Data(df=pd.DataFrame.from_records(records))
    save_generation_strategy(generation_strategy=standalone_gs)
    self.assertEqual(
        standalone_gs,
        load_generation_strategy_by_id(gs_id=standalone_gs._db_id),
    )

    # Part 2: a strategy attached to an experiment, reloaded by experiment name.
    experiment = get_branin_experiment()
    gs = get_generation_strategy()
    save_experiment(experiment)

    def save_and_reload():
        save_generation_strategy(generation_strategy=gs)
        return load_generation_strategy_by_experiment_name(
            experiment_name=experiment.name
        )

    # First generator run.
    experiment.new_trial(generator_run=gs.gen(experiment))
    self.assertEqual(gs, save_and_reload())

    # Second generator run, this time with data.
    experiment.new_trial(
        generator_run=gs.gen(experiment, new_data=get_branin_data())
    )
    self.assertEqual(gs, save_and_reload())

    # Experiment edits must be visible through the reloaded strategy too.
    experiment.description = "foobar"
    save_experiment(experiment)
    reloaded_gs = load_generation_strategy_by_experiment_name(
        experiment_name=experiment.name
    )
    self.assertEqual(gs, reloaded_gs)
    self.assertEqual(gs._experiment.description, experiment.description)
    self.assertEqual(
        gs._experiment.description,
        reloaded_gs._experiment.description,
    )
def testEncodeDecodeGenerationStrategy(self):
    """Generation strategies encode/decode before and after generating trials.

    NOTE(review): another method named ``testEncodeDecodeGenerationStrategy``
    appears later in this file. If both are in the same class, this one is
    shadowed and never runs.
    """
    # Nothing has been saved yet, so loading by id fails.
    with self.assertRaises(ValueError):
        load_generation_strategy_by_id(gs_id=0)

    # Round-trip a fresh strategy with no trials, no data and no experiment.
    original_gs = get_generation_strategy()
    save_generation_strategy(generation_strategy=original_gs)
    restored_gs = load_generation_strategy_by_id(gs_id=original_gs._db_id)
    self.assertEqual(original_gs, restored_gs)
    self.assertIsNone(original_gs._experiment)

    # No strategy is associated with this experiment yet.
    experiment = get_branin_experiment()
    save_experiment(experiment)
    with self.assertRaises(ValueError):
        load_generation_strategy_by_experiment_name(
            experiment_name=experiment.name
        )

    # Continue with the restored strategy: two trials, the second fed data.
    gs = restored_gs
    experiment.new_trial(gs.gen(experiment=experiment))
    experiment.new_trial(gs.gen(experiment, data=get_branin_data()))
    save_generation_strategy(generation_strategy=gs)
    save_experiment(experiment)

    # Restore via the experiment the strategy is attached to.
    reloaded_gs = load_generation_strategy_by_experiment_name(
        experiment_name=experiment.name
    )
    # `_seen_trial_indices_by_status` is not stored in the DB, so it is None
    # on the reloaded strategy; clear it on the original before comparing.
    gs._seen_trial_indices_by_status = None
    self.assertEqual(gs, reloaded_gs)
    self.assertIsInstance(reloaded_gs._steps[0].model, Models)
    self.assertIsInstance(reloaded_gs.model, ModelBridge)
    self.assertEqual(len(reloaded_gs._generator_runs), 2)
    self.assertEqual(reloaded_gs._experiment._name, experiment._name)
def testEncodeDecodeGenerationStrategy(self):
    """Generation strategies encode/decode before and after generating trials.

    Also checks that the strategy's ``_generated`` / ``_observed``
    collections start empty and are populated once trials are generated.
    """
    # Nothing has been saved yet, so loading by id fails.
    with self.assertRaises(ValueError):
        load_generation_strategy_by_id(gs_id=0)

    # Round-trip a fresh strategy that has no experiment attached.
    fresh_gs = get_generation_strategy()
    save_generation_strategy(generation_strategy=fresh_gs)
    by_id = load_generation_strategy_by_id(gs_id=fresh_gs._db_id)
    self.assertEqual(fresh_gs, by_id)
    self.assertIsNone(fresh_gs._experiment)
    self.assertEqual(len(fresh_gs._generated), 0)
    self.assertEqual(len(fresh_gs._observed), 0)

    # No strategy is associated with this experiment yet.
    experiment = get_branin_experiment()
    save_experiment(experiment)
    with self.assertRaises(ValueError):
        load_generation_strategy_by_experiment_name(
            experiment_name=experiment.name
        )

    # Continue with the reloaded strategy: two trials, the second fed data.
    gs = by_id
    experiment.new_trial(generator_run=gs.gen(experiment))
    experiment.new_trial(gs.gen(experiment, new_data=get_branin_data()))
    self.assertGreater(len(gs._generated), 0)
    self.assertGreater(len(gs._observed), 0)
    save_generation_strategy(generation_strategy=gs)
    save_experiment(experiment)

    # Restore via the experiment the strategy is attached to.
    by_experiment = load_generation_strategy_by_experiment_name(
        experiment_name=experiment.name
    )
    self.assertEqual(gs, by_experiment)
    self.assertIsInstance(by_experiment._steps[0].model, Models)
    self.assertIsInstance(by_experiment.model, ModelBridge)
    self.assertEqual(len(by_experiment._generator_runs), 2)
    self.assertEqual(by_experiment._experiment._name, experiment._name)
def testExperimentObjectiveThresholdUpdates(self):
    """Multi-objective configs with thresholds/constraints persist and reload."""

    def metric_rows():
        return get_session().query(SQAMetric).count()

    experiment = get_experiment_with_batch_trial()
    save_experiment(experiment)
    self.assertEqual(metric_rows(), len(experiment.metrics))

    # Install a multi-objective config carrying one objective threshold
    # (should perform update in place).
    mo_config = get_multi_objective_optimization_config()
    mo_config.objective_thresholds = [get_objective_threshold()]
    experiment.optimization_config = mo_config
    save_experiment(experiment)
    self.assertEqual(metric_rows(), 6)

    # Add a second outcome constraint on a new metric.
    extra_constraint = OutcomeConstraint(
        metric=Metric(name="outcome"), op=ComparisonOp.GEQ, bound=-0.5
    )
    first_constraint = mo_config.outcome_constraints[0]
    mo_config.outcome_constraints = [first_constraint, extra_constraint]
    experiment.optimization_config = mo_config
    save_experiment(experiment)
    self.assertEqual(metric_rows(), 7)

    # Drop all outcome constraints (removed ones should become tracking metrics).
    mo_config.outcome_constraints = []
    experiment.optimization_config = mo_config
    save_experiment(experiment)
    self.assertEqual(metric_rows(), 5)
    self.assertEqual(experiment, load_experiment(experiment.name))

    # The config must also reload correctly with no objective thresholds.
    mo_config.objective_thresholds = []
    save_experiment(experiment)
    self.assertEqual(metric_rows(), 4)
    self.assertEqual(experiment, load_experiment(experiment.name))
def testRegistryAdditions(self):
    """Custom Runner/Metric subclasses round-trip once registered.

    Registers locally defined ``MyRunner`` and ``MyMetric`` classes with the
    store, attaches instances of them to an experiment, saves it, and checks
    that the reloaded experiment equals the original.
    """

    class MyRunner(Runner):
        # Instance methods need `self`. Without it, any call made on an
        # instance raises TypeError.
        def run(self, trial):
            return {}

        # NOTE(review): assumes the base `Runner.staging_required` is a
        # property (as in Ax). Overriding it with a plain method would make
        # `runner.staging_required` a bound method, which is always truthy.
        @property
        def staging_required(self):
            return False

    class MyMetric(Metric):
        pass

    register_metric(MyMetric)
    register_runner(MyRunner)

    experiment = get_experiment_with_batch_trial()
    experiment.runner = MyRunner()
    experiment.add_tracking_metric(MyMetric(name="my_metric"))
    save_experiment(experiment)
    loaded_experiment = load_experiment(experiment.name)
    self.assertEqual(loaded_experiment, experiment)
def testExperimentObjectiveUpdates(self):
    """SQAMetric row count stays consistent as the objective is edited or replaced."""
    experiment = get_experiment_with_batch_trial()

    def assert_metric_rows_match():
        self.assertEqual(
            get_session().query(SQAMetric).count(), len(experiment.metrics)
        )

    save_experiment(experiment)
    assert_metric_rows_match()

    # Flip the objective to minimize (should perform update in place).
    config = get_optimization_config()
    flipped = get_objective()
    flipped.minimize = True
    config.objective = flipped
    experiment.optimization_config = config
    save_experiment(experiment)
    assert_metric_rows_match()

    # Replace the objective outright (old one should become tracking metric).
    config.objective = Objective(metric=Metric(name="objective"))
    experiment.optimization_config = config
    save_experiment(experiment)
    assert_metric_rows_match()

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentGeneratorRunUpdates(self):
    """SQAGeneratorRun rows grow as arms and generator runs are added to a trial."""

    def generator_run_rows():
        return get_session().query(SQAGeneratorRun).count()

    experiment = get_experiment_with_batch_trial()
    save_experiment(experiment)
    # One main generator run plus one for the status quo.
    self.assertEqual(generator_run_rows(), 2)

    # Adding an arm creates one wrapper generator run. It also replaces the
    # status quo generator run, since the status quo's weight changes.
    batch = experiment.trials[0]
    batch.add_arm(get_arm())
    save_experiment(experiment)
    self.assertEqual(generator_run_rows(), 3)

    extra_run = get_generator_run()
    # TODO[Lena, T46190605]: remove
    for attr_name in ("_model_key", "_model_kwargs", "_bridge_kwargs"):
        setattr(extra_run, attr_name, None)
    batch.add_generator_run(generator_run=extra_run, multiplier=0.5)
    save_experiment(experiment)
    self.assertEqual(generator_run_rows(), 4)

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentOverwriting(self):
    """Saving a clashing experiment fails unless overwrite=True is passed.

    NOTE(review): the clash presumably comes from sharing self.experiment's
    name. self.experiment is set up outside this method, so confirm.
    """
    save_experiment(self.experiment)

    clashing = get_experiment_with_batch_trial()
    # Hack: push time_created forward a second, otherwise the two creation
    # timestamps would be too close together.
    one_second = timedelta(seconds=1)
    clashing._time_created = clashing.time_created + one_second

    with self.assertRaises(Exception):
        save_experiment(clashing)
    save_experiment(clashing, overwrite=True)
def test_copy_db_ids_none_search_space(self):
    """copy_db_ids warns, but still copies, when the source search space is None."""
    source = get_experiment_with_batch_trial()
    save_experiment(source)
    target = load_experiment(source.name)
    self.assertEqual(source, target)

    # Remove the source's search space so the two experiments now differ
    # in type at that attribute.
    source._search_space = None
    # Clear a few db_ids on the target so there is something to copy.
    target_trial = target.trials[0]
    target_trial.db_id = None
    target_trial.generator_runs[0].arms[0].db_id = None

    with self.assertWarnsRegex(
        Warning,
        "Encountered two objects of different types",
    ):
        copy_db_ids(source, target)

    # Null the target's search space as well so the two compare equal.
    target._search_space = None
    self.assertEqual(source, target)
def testExperimentParameterUpdates(self):
    """SQAParameter rows follow parameter updates, additions and removals."""
    experiment = get_experiment_with_batch_trial()

    def assert_parameter_rows_match():
        self.assertEqual(
            get_session().query(SQAParameter).count(),
            len(experiment.search_space.parameters),
        )

    def store(search_space):
        experiment.search_space = search_space
        save_experiment(experiment)
        assert_parameter_rows_match()

    save_experiment(experiment)
    assert_parameter_rows_match()

    # Extend an existing choice parameter (should perform update in place).
    space = get_search_space()
    choice = get_choice_parameter()
    choice.add_values(["foobar"])
    space.update_parameter(choice)
    store(space)

    # Add a new float range parameter.
    space.add_parameter(
        RangeParameter(
            name="x1", parameter_type=ParameterType.FLOAT, lower=-5, upper=10
        )
    )
    store(space)

    # Remove it again (old one should be deleted).
    del space._parameters["x1"]
    store(space)

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentParameterConstraintUpdates(self):
    """SQAParameterConstraint rows follow constraint additions, edits and removals."""
    experiment = get_experiment_with_batch_trial()

    def assert_constraint_rows_match():
        self.assertEqual(
            get_session().query(SQAParameterConstraint).count(),
            len(experiment.search_space.parameter_constraints),
        )

    def store(search_space):
        experiment.search_space = search_space
        save_experiment(experiment)
        assert_constraint_rows_match()

    save_experiment(experiment)
    # The fixture is annotated as starting with 3 constraints.
    assert_constraint_rows_match()

    # Add a sum constraint.
    space = experiment.search_space
    kept_constraint = experiment.search_space.parameter_constraints[0]
    added_constraint = get_sum_constraint2()
    space.add_parameter_constraints([added_constraint])
    store(space)

    # Change the new constraint's bound. Constraints have no UIDs, so the old
    # row is thrown out and a new one created.
    added_constraint.bound = 5.0
    space.set_parameter_constraints([kept_constraint, added_constraint])
    store(space)

    # Keep only the new constraint (the others should be deleted).
    space.set_parameter_constraints([added_constraint])
    store(space)

    self.assertEqual(experiment, load_experiment(experiment.name))
def testExperimentOutcomeConstraintUpdates(self):
    """SQAMetric row count stays consistent as outcome constraints change."""
    experiment = get_experiment_with_batch_trial()

    def save_and_check():
        save_experiment(experiment)
        self.assertEqual(
            get_session().query(SQAMetric).count(), len(experiment.metrics)
        )

    def apply_constraints(config, constraints):
        config.outcome_constraints = constraints
        experiment.optimization_config = config
        save_and_check()

    save_and_check()

    # Change the bound of the fixture's outcome constraint (should perform
    # update in place).
    config = get_optimization_config()
    primary = get_outcome_constraint()
    primary.bound = -1.0
    apply_constraints(config, [primary])

    # Add a second constraint on a new metric.
    secondary = OutcomeConstraint(
        metric=Metric(name="outcome"), op=ComparisonOp.GEQ, bound=-0.5
    )
    apply_constraints(config, [primary, secondary])

    # Drop it again (old one should become tracking metric).
    apply_constraints(config, [primary])

    self.assertEqual(experiment, load_experiment(experiment.name))
def testUpdateGenerationStrategy(self):
    """An experiment-attached generation strategy round-trips after each update.

    NOTE(review): another method named ``testUpdateGenerationStrategy``
    appears later in this file. If both are in the same class, this one is
    shadowed and never runs.
    """
    # A standalone strategy can be saved without an experiment.
    save_generation_strategy(generation_strategy=get_generation_strategy())

    experiment = get_branin_experiment()
    gs = get_generation_strategy()
    save_experiment(experiment)

    def reload_gs():
        return load_generation_strategy_by_experiment_name(
            experiment_name=experiment.name
        )

    # First generator run.
    experiment.new_trial(generator_run=gs.gen(experiment))
    save_generation_strategy(generation_strategy=gs)
    reloaded = reload_gs()
    # `_seen_trial_indices_by_status` is not stored in the DB, so it is None
    # on the reloaded strategy; clear it on the original before comparing.
    gs._seen_trial_indices_by_status = None
    self.assertEqual(gs, reloaded)

    # Second generator run, this time with data.
    experiment.new_trial(
        generator_run=gs.gen(experiment, data=get_branin_data())
    )
    save_generation_strategy(generation_strategy=gs)
    save_experiment(experiment)
    reloaded = reload_gs()
    # Same reason as above: the reloaded strategy has no seen-trial mapping.
    gs._seen_trial_indices_by_status = None
    self.assertEqual(gs, reloaded)

    # Experiment edits must be visible through the reloaded strategy too.
    experiment.description = "foobar"
    save_experiment(experiment)
    reloaded = reload_gs()
    self.assertEqual(gs, reloaded)
    self.assertEqual(gs._experiment.description, experiment.description)
    self.assertEqual(
        gs._experiment.description,
        reloaded._experiment.description,
    )
def testExperimentUpdates(self):
    """Editing an experiment updates its single SQAExperiment row in place."""

    def experiment_rows():
        return get_session().query(SQAExperiment).count()

    experiment = get_experiment_with_batch_trial()
    save_experiment(experiment)
    self.assertEqual(experiment_rows(), 1)

    # Description change (should perform update in place).
    experiment.description = "foobar"
    save_experiment(experiment)
    self.assertEqual(experiment_rows(), 1)

    # New status quo arm (still a single experiment row).
    experiment.status_quo = Arm(
        parameters={"w": 0.0, "x": 1, "y": "y", "z": True}
    )
    save_experiment(experiment)
    self.assertEqual(experiment_rows(), 1)

    self.assertEqual(experiment, load_experiment(experiment.name))
def testUpdateGenerationStrategy(self):
    """An experiment-attached generation strategy round-trips after each update."""
    # A standalone strategy can be saved without an experiment.
    save_generation_strategy(generation_strategy=get_generation_strategy())

    experiment = get_branin_experiment()
    gs = get_generation_strategy()
    save_experiment(experiment)

    def reload_gs():
        return load_generation_strategy_by_experiment_name(
            experiment_name=experiment.name
        )

    # First generator run.
    experiment.new_trial(generator_run=gs.gen(experiment))
    save_generation_strategy(generation_strategy=gs)
    self.assertEqual(gs, reload_gs())

    # Second generator run, this time with data.
    experiment.new_trial(
        generator_run=gs.gen(experiment, data=get_branin_data())
    )
    save_generation_strategy(generation_strategy=gs)
    save_experiment(experiment)
    reloaded = reload_gs()
    # When the model is restored from the last generator run, the reloaded
    # strategy takes `_seen_trial_indices_by_status` from the experiment.
    # The experiment was updated after the last `gen`, so the in-memory
    # strategy has not yet "seen" trial 1 as a candidate. Align that mapping
    # before comparing everything else.
    gs._seen_trial_indices_by_status[TrialStatus.CANDIDATE].add(1)
    self.assertEqual(gs, reloaded)

    # Experiment edits must be visible through the reloaded strategy too.
    experiment.description = "foobar"
    save_experiment(experiment)
    reloaded = reload_gs()
    self.assertEqual(gs, reloaded)
    self.assertEqual(gs._experiment.description, experiment.description)
    self.assertEqual(
        gs._experiment.description,
        reloaded._experiment.description,
    )