Exemplo n.º 1
0
def test_trials_get_best_hyperparams_should_return_hyperparams_of_best_trial():
    """Check that ``Trials.get_best_hyperparams`` returns the winning trial's samples.

    Two trials are run in order, each receiving one successful validation
    split: the first with ``best_score=0.2``, the second with
    ``best_score=0.1``. The hyperparameters of the second trial are expected
    back.

    NOTE(review): the expected winner has the lower score, so the main metric
    is presumably minimized -- confirm against the ``Trials`` implementation.
    """
    # Given: (raw hyperparameter dict, best score) for each trial, in run order.
    trial_specs = (
        ({'a': 2}, 0.2),
        ({'b': 3}, 0.1),
    )
    all_samples = []
    finished_trials = []
    for raw_hyperparams, score in trial_specs:
        samples = HyperparameterSamples(raw_hyperparams)
        trial = Trial(hyperparams=samples, main_metric_name=MAIN_METRIC_NAME)
        with trial:
            given_success_trial_validation_split(trial, best_score=score)
        all_samples.append(samples)
        finished_trials.append(trial)

    trials = Trials(trials=finished_trials)

    # When
    best_hyperparams = trials.get_best_hyperparams()

    # Then: the samples of the second trial (score 0.1) must be returned.
    expected_hyperparams = all_samples[1]
    assert best_hyperparams == expected_hyperparams
Exemplo n.º 2
0
    def test_trials_get_best_hyperparams_should_return_hyperparams_of_best_trial(self):
        """Check that ``Trials.get_best_hyperparams`` returns the winning trial's samples.

        The fixture trial (``self.trial``) gets a successful validation split
        with ``best_score=0.2``; a second trial built here gets one with
        ``best_score=0.1``. The second trial's hyperparameters are expected
        back.

        NOTE(review): the expected winner has the lower score, so the main
        metric is presumably minimized -- confirm against ``Trials``.
        """
        # Given: the first trial comes from the test fixture.
        fixture_trial = self.trial
        with fixture_trial:
            self._given_success_trial_validation_split(
                fixture_trial, best_score=0.2)

        # ...and the second one is built explicitly, saved through the repo.
        expected_hyperparams = HyperparameterSamples({'b': 3})
        second_trial_kwargs = {
            'trial_number': 1,
            'save_trial_function': self.repo.save_trial,
            'hyperparams': expected_hyperparams,
            'main_metric_name': MAIN_METRIC_NAME,
        }
        second_trial = Trial(**second_trial_kwargs)
        with second_trial:
            self._given_success_trial_validation_split(
                second_trial, best_score=0.1)

        trials = Trials(trials=[fixture_trial, second_trial])

        # When
        best_hyperparams = trials.get_best_hyperparams()

        # Then
        assert best_hyperparams == expected_hyperparams