예제 #1
0
    def load_all_trials(self, status: 'TRIAL_STATUS' = None) -> 'Trials':
        """
        Load every hyperparameter trial saved as a json file in the cache folder.

        Reads all ``*.json`` files directly inside ``self.cache_folder``, ordered
        by file modification time (oldest first), and rebuilds a ``Trial`` from
        each one via ``Trial.from_json``. Files that fail to load are reported
        on stdout (path plus traceback) and skipped.

        :param status: if given, keep only trials whose saved ``'status'`` field
            equals ``status.value``; if ``None``, keep every trial.
        :return: a ``Trials`` collection containing the loaded trials.
        """
        trials = Trials()

        files = glob.glob(os.path.join(self.cache_folder, '*.json'))

        # Order by modification time (nanosecond precision) so trials are
        # returned oldest-first.
        files.sort(key=lambda filename: os.stat(filename).st_mtime_ns)

        for base_path in files:
            with open(base_path) as f:
                try:
                    trial_json = json.load(f)
                except Exception:
                    # Best-effort load: report which file is broken and keep
                    # loading the remaining trials.
                    print('invalid trial json file {}'.format(base_path))
                    print(traceback.format_exc())
                    continue

            if status is None or trial_json['status'] == status.value:
                # Each rebuilt trial persists its updates through save_trial.
                trials.append(Trial.from_json(
                    update_trial_function=self.save_trial,
                    trial_json=trial_json
                ))

        return trials
예제 #2
0
    def __init__(self,
                 hyperparameter_selection_strategy=None,
                 cache_folder: str = None,
                 best_retrained_model_folder=None):
        """
        Initialize the repository and start with an empty ``Trials`` collection.

        :param hyperparameter_selection_strategy: forwarded to ``HyperparamsRepository``.
        :param cache_folder: forwarded to ``HyperparamsRepository`` and also stored
            on ``self.cache_folder``.
        :param best_retrained_model_folder: forwarded to ``HyperparamsRepository``.
        """
        base_kwargs = dict(
            hyperparameter_selection_strategy=hyperparameter_selection_strategy,
            cache_folder=cache_folder,
            best_retrained_model_folder=best_retrained_model_folder,
        )
        HyperparamsRepository.__init__(self, **base_kwargs)

        # Assigned after the base initializer so this value is the final one.
        self.cache_folder = cache_folder
        self.trials = Trials()
예제 #3
0
    def __init__(self, hyperparameter_selection_strategy=None, print_func: Callable = None, cache_folder: str = None,
                 best_retrained_model_folder=None):
        """
        Initialize the repository with an output function and an empty ``Trials`` collection.

        :param hyperparameter_selection_strategy: forwarded to ``HyperparamsRepository``.
        :param print_func: callable stored on ``self.print_func``; defaults to the
            builtin ``print`` when ``None``.
        :param cache_folder: forwarded to ``HyperparamsRepository`` and also stored
            on ``self.cache_folder``.
        :param best_retrained_model_folder: forwarded to ``HyperparamsRepository``.
        """
        HyperparamsRepository.__init__(
            self,
            hyperparameter_selection_strategy=hyperparameter_selection_strategy,
            cache_folder=cache_folder,
            best_retrained_model_folder=best_retrained_model_folder,
        )
        self.print_func = print if print_func is None else print_func
        # Assigned after the base initializer so this value is the final one.
        self.cache_folder = cache_folder
        self.trials = Trials()
예제 #4
0
def test_trials_get_best_hyperparams_should_return_hyperparams_of_best_trial():
    """
    ``Trials.get_best_hyperparams`` should return the hyperparams of the trial
    whose successful validation split has best_score 0.1, not the one with 0.2.
    """
    def make_finished_trial(hyperparams, best_score):
        # Build a trial and record one successful validation split inside it.
        trial = Trial(hyperparams=hyperparams, main_metric_name=MAIN_METRIC_NAME)
        with trial:
            given_success_trial_validation_split(trial, best_score=best_score)
        return trial

    # Given
    first_hyperparams = HyperparameterSamples({'a': 2})
    first_trial = make_finished_trial(first_hyperparams, best_score=0.2)

    second_hyperparams = HyperparameterSamples({'b': 3})
    second_trial = make_finished_trial(second_hyperparams, best_score=0.1)

    all_trials = Trials(trials=[first_trial, second_trial])

    # When
    result = all_trials.get_best_hyperparams()

    # Then
    assert result == second_hyperparams
예제 #5
0
    def test_trials_get_best_hyperparams_should_return_hyperparams_of_best_trial(
            self):
        """
        ``Trials.get_best_hyperparams`` should return the hyperparams of the
        second trial (best_score 0.1) rather than the fixture trial (best_score 0.2).
        """
        # Given: the fixture trial finishes a validation split with best_score 0.2.
        first_trial = self.trial
        with first_trial:
            self._given_success_trial_validation_split(first_trial, best_score=0.2)

        # ...and a second trial, saved through the same repo, finishes with 0.1.
        second_hyperparams = HyperparameterSamples({'b': 3})
        second_trial = Trial(
            trial_number=1,
            save_trial_function=self.repo.save_trial,
            hyperparams=second_hyperparams,
            main_metric_name=MAIN_METRIC_NAME,
        )
        with second_trial:
            self._given_success_trial_validation_split(second_trial, best_score=0.1)

        # When
        best = Trials(trials=[first_trial, second_trial]).get_best_hyperparams()

        # Then
        assert best == second_hyperparams