Example no. 1
    def _attempt_trial(self, trial_number, validation_splits,
                       context: ExecutionContext):
        """
        Run a single AutoML trial: open a new trial in the hyperparams
        repository, execute it with the trainer on the given validation
        splits, and record success or failure.

        Exceptions listed in ``self.error_types_to_raise`` mark the trial as
        failed and are re-raised to abort the loop; any other exception is
        logged and swallowed so the remaining trials can proceed.

        :param trial_number: zero-based index of this trial
        :param validation_splits: validation splits to evaluate the trial on
        :param context: execution context (its logger is swapped for the
            per-trial logger for the duration of the trial)
        """
        # Pre-bind both names: if AutoMLContainer(...) or new_trial(...)
        # raises before the `with` body runs, the except/finally clauses
        # below would otherwise crash with UnboundLocalError and mask the
        # real error.
        repo_trial = None
        repo_trial_split = None
        try:
            auto_ml_data = AutoMLContainer(
                trial_number=trial_number,
                trials=self.hyperparams_repository.load_all_trials(
                    TRIAL_STATUS.SUCCESS),
                hyperparameter_space=self.pipeline.get_hyperparams_space(),
                main_scoring_metric_name=self.trainer.get_main_metric_name())

            with self.hyperparams_repository.new_trial(
                    auto_ml_data) as repo_trial:
                context.set_logger(repo_trial.logger)
                context.logger.info('trial {}/{}'.format(
                    trial_number + 1, self.n_trial))

                repo_trial_split = self.trainer.execute_trial(
                    pipeline=self.pipeline,
                    context=context,
                    repo_trial=repo_trial,
                    validation_splits=validation_splits,
                    n_trial=self.n_trial)
        except self.error_types_to_raise as error:
            track = traceback.format_exc()
            if repo_trial is not None:
                repo_trial.set_failed(error)
            context.logger.critical(track)
            raise error
        except Exception:
            track = traceback.format_exc()
            # Next split to run: 0 if none completed yet, else one past the
            # last completed split.
            repo_trial_split_number = 0 if repo_trial_split is None else repo_trial_split.split_number + 1
            context.logger.error('failed trial {}'.format(
                _get_trial_split_description(
                    repo_trial=repo_trial,
                    repo_trial_split_number=repo_trial_split_number,
                    validation_splits=validation_splits,
                    trial_number=trial_number,
                    n_trial=self.n_trial)))
            context.logger.error(track)
        finally:
            if repo_trial is not None:
                repo_trial.update_final_trial_status()
            # Some heavy objects might have stayed in memory for a while during the execution of our trial;
            # It is best to do a full collection as that may free up some ram.
            gc.collect()
Example no. 2
    def _fit_data_container(self, data_container: DataContainer,
                            context: ExecutionContext) -> 'BaseStep':
        """
        Run the Auto ML loop.
        Search for the best hyperparams with the hyperparameter optimizer,
        evaluating every trial on the validation splits, then optionally
        refit the best pipeline on the whole data.

        :param data_container: data container to fit
        :param context: execution context

        :return: self
        """
        splits = self.validation_splitter.split_data_container(
            data_container=data_container, context=context)

        # Remember the main logger: each trial swaps in its own logger,
        # so it must be restored once all trials are done.
        root_logger = context.logger

        if self.n_jobs in (-1, None, 1):
            # Sequential execution in the current process.
            for trial_index in range(self.n_trial):
                self._attempt_trial(trial_index, splits, context)
        else:
            context.logger.info(
                f"Number of processors available: {multiprocessing.cpu_count()}"
            )

            # An in-memory repository cannot be shared across processes.
            if isinstance(self.hyperparams_repository,
                          InMemoryHyperparamsRepository):
                raise ValueError(
                    "Cannot use InMemoryHyperparamsRepository for multiprocessing, use json-based repository."
                )

            worker_count = self.n_jobs
            if worker_count < -1:
                # Map n_jobs below -1 onto cpu_count() + 1 + n_jobs
                # (e.g. -2 -> all CPUs but one).
                worker_count = multiprocessing.cpu_count() + 1 + self.n_jobs

            task_args = [(self, trial_index, splits, context)
                         for trial_index in range(self.n_trial)]
            with multiprocessing.get_context("spawn").Pool(
                    processes=worker_count) as pool:
                pool.starmap(AutoML._attempt_trial, task_args)

        context.set_logger(root_logger)

        best_hyperparams = self.hyperparams_repository.get_best_hyperparams()

        dumped_hyperparams = json.dumps(best_hyperparams.to_nested_dict(),
                                        sort_keys=True,
                                        indent=4)
        context.logger.info(f'\nbest hyperparams: {dumped_hyperparams}')

        # Notify HyperparamsRepository subscribers
        self.hyperparams_repository.on_complete(
            value=self.hyperparams_repository)

        if self.refit_trial:
            best_pipeline: BaseStep = self._load_virgin_model(
                hyperparams=best_hyperparams)
            best_pipeline = self.trainer.refit(
                p=best_pipeline,
                data_container=data_container,
                context=context.set_execution_phase(ExecutionPhase.TRAIN))

            self.hyperparams_repository.save_best_model(best_pipeline)

        return self