def test_trials_get_best_hyperparams_should_return_hyperparams_of_best_trial():
    """Check which hyperparameters ``Trials.get_best_hyperparams`` selects.

    Two trials are built. The first gets a successful validation split with
    ``best_score=0.2`` and the second one with ``best_score=0.1``. The test
    expects the hyperparameters of the second (0.1) trial to be returned.
    """

    def _completed_trial(hyperparams, best_score):
        # Create a Trial and record a successful validation split inside its context.
        trial = Trial(hyperparams=hyperparams, main_metric_name=MAIN_METRIC_NAME)
        with trial:
            given_success_trial_validation_split(trial, best_score=best_score)
        return trial

    # Given
    first_hyperparams = HyperparameterSamples({'a': 2})
    first_trial = _completed_trial(first_hyperparams, 0.2)

    expected_hyperparams = HyperparameterSamples({'b': 3})
    second_trial = _completed_trial(expected_hyperparams, 0.1)

    all_trials = Trials(trials=[first_trial, second_trial])

    # When
    actual_hyperparams = all_trials.get_best_hyperparams()

    # Then
    assert actual_hyperparams == expected_hyperparams
def test_trials_get_best_hyperparams_should_return_hyperparams_of_best_trial(
        self):
    """Check which hyperparameters ``Trials.get_best_hyperparams`` selects.

    The fixture trial (``self.trial``) gets a successful validation split
    with ``best_score=0.2``. A second trial (``trial_number=1``, saving
    through ``self.repo.save_trial``) gets one with ``best_score=0.1``. The
    test expects the second trial's hyperparameters to be returned.
    """

    def _record_split(trial, best_score):
        # Enter the trial's context and record a successful validation split.
        with trial:
            self._given_success_trial_validation_split(trial, best_score=best_score)
        return trial

    # Given
    first_trial = _record_split(self.trial, 0.2)

    expected_hyperparams = HyperparameterSamples({'b': 3})
    second_trial = Trial(
        trial_number=1,
        save_trial_function=self.repo.save_trial,
        hyperparams=expected_hyperparams,
        main_metric_name=MAIN_METRIC_NAME,
    )
    second_trial = _record_split(second_trial, 0.1)

    all_trials = Trials(trials=[first_trial, second_trial])

    # When
    actual_hyperparams = all_trials.get_best_hyperparams()

    # Then
    assert actual_hyperparams == expected_hyperparams