Пример #1
0
    def _execute_trial(self, trial_number: int, repo_trial: Trial,
                       context: ExecutionContext,
                       validation_splits: List[Tuple[DataContainer,
                                                     DataContainer]]):
        for training_data_container, validation_data_container in validation_splits:
            p = copy.deepcopy(self.pipeline)
            p.update_hyperparams(repo_trial.hyperparams)
            repo_trial.set_hyperparams(p.get_hyperparams())

            with repo_trial.new_validation_split(p) as repo_trial_split:
                trial_split_description = self._get_trial_split_description(
                    repo_trial=repo_trial,
                    repo_trial_split=repo_trial_split,
                    validation_splits=validation_splits,
                    trial_number=trial_number)

                self.print_func(
                    'fitting trial {}'.format(trial_split_description))

                repo_trial_split = self.trainer.fit_trial_split(
                    trial_split=repo_trial_split,
                    train_data_container=training_data_container,
                    validation_data_container=validation_data_container,
                    context=context)

                repo_trial_split.set_success()

                self.print_func('success trial {} score: {}'.format(
                    trial_split_description,
                    repo_trial_split.get_validation_score()))

        return repo_trial_split
Пример #2
0
    def execute_trial(self,
                      pipeline: BaseStep,
                      trial_number: int,
                      repo_trial: Trial,
                      context: ExecutionContext,
                      validation_splits: List[Tuple[DataContainer,
                                                    DataContainer]],
                      n_trial: int,
                      delete_pipeline_on_completion: bool = True):
        """
        Train a copy of ``pipeline`` on each validation split of a trial.

        For every ``(train, validation)`` pair, a deep copy of ``pipeline`` is
        updated with ``repo_trial.hyperparams``, registered as a new validation
        split of ``repo_trial``, fitted with ``self.fit_trial_split`` and marked
        as successful. Progress, the best validation score and the epoch count
        to reach it are logged through ``context.logger``. Per-epoch metric
        tracking presumably happens inside ``fit_trial_split``; it is not done
        in this method.

        :param pipeline: pipeline to train on; it is deep-copied for every
            split and never fitted directly by this method
        :param trial_number: trial number, passed to the split description
            used in the log messages
        :param repo_trial: trial whose hyperparameters are applied and which
            receives one validation split per pair
        :param context: execution context; its ``logger`` is used for progress
            messages and it is forwarded to ``fit_trial_split``
        :param validation_splits: non-empty list of
            ``(training_data_container, validation_data_container)`` pairs
        :param n_trial: total number of trials that will be executed, passed
            to the split description used in the log messages
        :param delete_pipeline_on_completion: forwarded to
            ``repo_trial.new_validation_split`` (presumably whether the split's
            pipeline is deleted once the split completes)
        :return: the trial split returned by ``fit_trial_split`` for the LAST
            pair in ``validation_splits``; earlier splits are only recorded on
            ``repo_trial``
        :raises ValueError: if ``validation_splits`` is empty, since there
            would be no trial split to return
        """
        if not validation_splits:
            raise ValueError(
                'validation_splits must contain at least one '
                '(training, validation) pair')

        for training_data_container, validation_data_container in validation_splits:
            # Independent copy so each split starts from the same, unfitted
            # pipeline.
            p = copy.deepcopy(pipeline)
            p.update_hyperparams(repo_trial.hyperparams)
            # Store the pipeline's complete hyperparameter set on the trial.
            repo_trial.set_hyperparams(p.get_hyperparams())

            repo_trial_split: TrialSplit = repo_trial.new_validation_split(
                pipeline=p,
                delete_pipeline_on_completion=delete_pipeline_on_completion)

            # The `with` binds __exit__ of the split created above, so the
            # reassignment from fit_trial_split below does not change which
            # object is closed.
            with repo_trial_split:
                trial_split_description = _get_trial_split_description(
                    repo_trial=repo_trial,
                    repo_trial_split_number=repo_trial_split.split_number,
                    validation_splits=validation_splits,
                    trial_number=trial_number,
                    n_trial=n_trial)

                context.logger.info(
                    'fitting trial {}'.format(trial_split_description))

                repo_trial_split = self.fit_trial_split(
                    trial_split=repo_trial_split,
                    train_data_container=training_data_container,
                    validation_data_container=validation_data_container,
                    context=context)

                repo_trial_split.set_success()

                context.logger.info(
                    'success trial {}\nbest score: {} at epoch {}'.format(
                        trial_split_description,
                        repo_trial_split.get_best_validation_score(),
                        repo_trial_split.get_n_epochs_to_best_validation_score()))

        return repo_trial_split