def refit(self, p: BaseStep, data_container: DataContainer, context: ExecutionContext) -> BaseStep:
    """
    Refit the pipeline on the whole dataset (without any validation technique).

    :param p: pipeline to refit
    :param data_container: data container
    :param context: execution context
    :return: fitted pipeline
    """
    context.set_execution_phase(ExecutionPhase.TRAIN)

    for _ in range(self.epochs):
        p = p.handle_fit(data_container, context)

    return p
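# Minimal usage sketch (illustrative only, not part of the library): assumes a
# `trainer` owning the `refit` method above, plus a `pipeline`, `data_container`,
# and `context` such as those prepared by the AutoML loop further down.
#
#     refitted_pipeline = trainer.refit(
#         p=pipeline,
#         data_container=data_container,
#         context=context.set_execution_phase(ExecutionPhase.TRAIN),
#     )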
def test_execswitch(tmpdir):
    context = ExecutionContext(root=tmpdir, execution_phase=ExecutionPhase.TRAIN)
    data_inputs = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    phase_to_step = {
        p: SomeStep()
        for p in (ExecutionPhase.PRETRAIN, ExecutionPhase.TRAIN, ExecutionPhase.TEST)
    }
    p = ExecutionPhaseSwitch(phase_to_step)
    p_c = p.with_context(context)

    # Only the step mapped to the context's current phase (TRAIN) should process the data.
    p_c.fit_transform(data_inputs)
    assert phase_to_step[ExecutionPhase.PRETRAIN].did_process is False
    assert phase_to_step[ExecutionPhase.TRAIN].did_process is True
    assert phase_to_step[ExecutionPhase.TEST].did_process is False

    # A phase with no mapped step raises a KeyError.
    p_c = p.with_context(context.set_execution_phase(ExecutionPhase.UNSPECIFIED))
    with pytest.raises(KeyError) as error_info:
        p_c.fit_transform(data_inputs)
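# Hedged follow-up sketch, reusing only calls exercised by the test above: switching
# the context to another mapped phase routes processing to that phase's step. Names
# mirror the test; this is illustrative rather than an additional assertion suite.
#
#     p_c = p.with_context(context.set_execution_phase(ExecutionPhase.TEST))
#     p_c.fit_transform(data_inputs)
#     assert phase_to_step[ExecutionPhase.TEST].did_process is True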
def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
    """
    Run the AutoML loop: find the best hyperparams using the hyperparameter optimizer,
    and evaluate the pipeline on each trial using a validation technique.

    :param data_container: data container to fit
    :param context: execution context
    :return: self
    """
    validation_splits = self.validation_splitter.split_data_container(
        data_container=data_container, context=context)

    # Keep a reference to the main logger so it can be restored after the trials.
    main_logger = context.logger

    if self.n_jobs in (-1, None, 1):
        # Sequential execution: run every trial in the current process.
        for trial_number in range(self.n_trial):
            self._attempt_trial(trial_number, validation_splits, context)
    else:
        context.logger.info(
            f"Number of processors available: {multiprocessing.cpu_count()}")

        if isinstance(self.hyperparams_repository, InMemoryHyperparamsRepository):
            raise ValueError(
                "Cannot use InMemoryHyperparamsRepository for multiprocessing, "
                "use a json-based repository.")

        # Values below -1 follow the sklearn convention: e.g. -2 uses all CPUs but one.
        n_jobs = self.n_jobs
        if n_jobs < -1:
            n_jobs = multiprocessing.cpu_count() + 1 + self.n_jobs

        with multiprocessing.get_context("spawn").Pool(processes=n_jobs) as pool:
            args = [(self, trial_number, validation_splits, context)
                    for trial_number in range(self.n_trial)]
            pool.starmap(AutoML._attempt_trial, args)

    context.set_logger(main_logger)

    best_hyperparams = self.hyperparams_repository.get_best_hyperparams()
    context.logger.info('\nbest hyperparams: {}'.format(
        json.dumps(best_hyperparams.to_nested_dict(), sort_keys=True, indent=4)))

    # Notify HyperparamsRepository subscribers.
    self.hyperparams_repository.on_complete(value=self.hyperparams_repository)

    if self.refit_trial:
        p: BaseStep = self._load_virgin_model(hyperparams=best_hyperparams)
        p = self.trainer.refit(
            p=p,
            data_container=data_container,
            context=context.set_execution_phase(ExecutionPhase.TRAIN))

        self.hyperparams_repository.save_best_model(p)

    return self
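# `_attempt_trial` is invoked above (sequentially and via `pool.starmap`) but is not
# shown in this section. The sketch below is an assumption about its body, mirroring
# the per-trial logic inlined by the other `_fit_data_container` variant further down
# (trial bookkeeping, execution, error handling, saving); it is illustrative, not the
# library's exact implementation.
def _attempt_trial(self, trial_number, validation_splits, context: ExecutionContext):
    try:
        auto_ml_data = AutoMLContainer(
            trial_number=trial_number,
            trials=self.hyperparams_repository.load_all_trials(TRIAL_STATUS.SUCCESS),
            hyperparameter_space=self.pipeline.get_hyperparams_space(),
            main_scoring_metric_name=self.trainer.get_main_metric_name())

        with self.hyperparams_repository.new_trial(auto_ml_data) as repo_trial:
            context.logger.info('trial {}/{}'.format(trial_number + 1, self.n_trial))
            self.trainer.execute_trial(
                pipeline=self.pipeline,
                trial_number=trial_number,
                repo_trial=repo_trial,
                context=context,
                validation_splits=validation_splits,
                n_trial=self.n_trial)
    except self.error_types_to_raise as error:
        # Critical error types are recorded on the trial and re-raised.
        repo_trial.set_failed(error)
        context.logger.critical(traceback.format_exc())
        raise
    except Exception:
        # Any other failure marks the trial as failed but lets the loop continue.
        context.logger.warning('failed trial {}/{}:\n{}'.format(
            trial_number + 1, self.n_trial, traceback.format_exc()))
    finally:
        repo_trial.update_final_trial_status()
        self._save_trial(repo_trial, trial_number)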
def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
    """
    Run the AutoML loop: find the best hyperparams using the hyperparameter optimizer,
    and evaluate the pipeline on each trial using a validation technique.

    :param data_container: data container to fit
    :param context: execution context
    :return: self
    """
    validation_splits = self.validation_splitter.split_data_container(
        data_container=data_container, context=context)

    for trial_number in range(self.n_trial):
        try:
            auto_ml_data = AutoMLContainer(
                trial_number=trial_number,
                trials=self.hyperparams_repository.load_all_trials(TRIAL_STATUS.SUCCESS),
                hyperparameter_space=self.pipeline.get_hyperparams_space(),
                main_scoring_metric_name=self.trainer.get_main_metric_name())

            with self.hyperparams_repository.new_trial(auto_ml_data) as repo_trial:
                repo_trial_split = None
                self.print_func('\ntrial {}/{}'.format(trial_number + 1, self.n_trial))

                repo_trial_split = self.trainer.execute_trial(
                    pipeline=self.pipeline,
                    trial_number=trial_number,
                    repo_trial=repo_trial,
                    context=context,
                    validation_splits=validation_splits,
                    n_trial=self.n_trial)
        except self.error_types_to_raise as error:
            # Critical error types are recorded on the trial and re-raised.
            track = traceback.format_exc()
            repo_trial.set_failed(error)
            self.print_func(track)
            raise error
        except Exception:
            # Any other failure marks the trial as failed but lets the loop continue.
            track = traceback.format_exc()

            repo_trial_split_number = 0 if repo_trial_split is None else repo_trial_split.split_number + 1
            self.print_func('failed trial {}'.format(
                _get_trial_split_description(
                    repo_trial=repo_trial,
                    repo_trial_split_number=repo_trial_split_number,
                    validation_splits=validation_splits,
                    trial_number=trial_number,
                    n_trial=self.n_trial)))
            self.print_func(track)
        finally:
            repo_trial.update_final_trial_status()
            self._save_trial(repo_trial, trial_number)

    best_hyperparams = self.hyperparams_repository.get_best_hyperparams()
    self.print_func('best hyperparams:\n{}'.format(
        json.dumps(best_hyperparams.to_nested_dict(), sort_keys=True, indent=4)))

    p: BaseStep = self._load_virgin_model(hyperparams=best_hyperparams)
    if self.refit_trial:
        p = self.trainer.refit(
            p=p,
            data_container=data_container,
            context=context.set_execution_phase(ExecutionPhase.TRAIN))

        self.hyperparams_repository.save_best_model(p)

    return self