Example #1
    def fit(self,
            X: Dict[str, Any],
            y: Any = None,
            **kwargs: Any) -> autoPyTorchComponent:
        """
        Fits a component by using an input dictionary with pre-requisites

        Args:
            X (X: Dict[str, Any]): Dependencies needed by current component to perform fit
            y (Any): not used. To comply with sklearn API

        Returns:
            A instance of self
        """
        # Make sure that the prerequisites are there
        self.check_requirements(X, y)

        # Setup the logger
        self.logger = get_named_client_logger(
            name=f"{X['num_run']}_{time.time()}",
            # Log to a user provided port else to the default logging port
            port=X['logger_port'] if 'logger_port' in X else
            logging.handlers.DEFAULT_TCP_LOGGING_PORT,
        )

        # Call the actual fit function.
        self._fit(X=X, y=y, **kwargs)

        return cast(autoPyTorchComponent, self.choice)
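
The example above only shows the client side: get_named_client_logger is AutoPyTorch-internal and returns a named logger whose records are shipped to a TCP logging port (DEFAULT_TCP_LOGGING_PORT unless a logger_port is supplied). Below is a rough, standard-library-only stand-in for that idea; the helper name make_client_logger and its internals are illustrative assumptions, not AutoPyTorch's API:

import logging
import logging.handlers


def make_client_logger(name: str,
                       port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT) -> logging.Logger:
    # Stand-in sketch: a named logger that forwards its records to a TCP
    # log server listening on `port`, similar in spirit to the client
    # loggers created throughout these examples.
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.handlers.SocketHandler('localhost', port))
    return logger


# Usage mirroring Example #1: one logger per run, named by run id and timestamp.
client_logger = make_client_logger(name="3_1700000000.0")
client_logger.info("fit started")  # delivered only if a TCP log listener is running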
Example #2
    def setup_logger(self, name: str, port: int) -> None:
        self._logger = get_named_client_logger(
            output_dir=self.temporary_directory,
            name=name,
            port=port,
        )
        self.context._logger = self._logger
        return
Example #3
    def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchComponent:
        """
        Fits a component using an input dictionary of prerequisites.

        Args:
            X (Dict[str, Any]): Dependencies needed by the current component to perform fit
            y (Any): Not used. Present to comply with the sklearn API

        Returns:
            An instance of self
        """
        # Make sure that the prerequisites are there
        self.check_requirements(X, y)

        # Setup the logger
        self.logger = get_named_client_logger(
            output_dir=X['backend'].temporary_directory,
            name=X['job_id'],
            # Log to a user provided port else to the default logging port
            port=X['logger_port'] if 'logger_port' in X else
            logging.handlers.DEFAULT_TCP_LOGGING_PORT,
        )

        fit_function = self._fit
        if X['use_pynisher']:
            wall_time_in_s = X['runtime'] if 'runtime' in X else None
            memory_limit = X['cpu_memory_limit'] if 'cpu_memory_limit' in X else None
            fit_function = pynisher.enforce_limits(
                wall_time_in_s=wall_time_in_s,
                mem_in_mb=memory_limit,
                logger=self.logger
            )(self._fit)

        # Call the actual fit function.
        state_dict = fit_function(
            X=X,
            y=y,
            **kwargs
        )

        if X['use_pynisher']:
            # Normally X['network'] is a reference to the object, so when we
            # train using X, the pipeline network is updated for free.
            # With multiprocessing (because of pynisher) we have to update
            # X['network'] manually. We do so in a way that every pipeline
            # component can see the new network -- via an update, not an
            # overwrite of the reference.
            state_dict = state_dict.result
            X['network'].load_state_dict(state_dict)

        # TODO: once we have the optimizer code, the pynisher object might have failed.
        # If so, we should report this call as a failure via fit_function.exit_status
        return cast(autoPyTorchComponent, self.choice)
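
Both fit variants above hand resource enforcement to pynisher when X['use_pynisher'] is set. Below is a minimal, self-contained sketch of that wrapping pattern, assuming the pynisher 0.x interface these examples use; the train function and its limits are made up, while exit_status is the attribute the examples themselves read:

import logging

import pynisher


def train(n_epochs: int) -> dict:
    # Stand-in for self._fit: pretend to train and return a state dict.
    return {'epochs_done': n_epochs}


# Wrapping the callable returns an object that runs it in a separate,
# resource-limited subprocess when called.
safe_train = pynisher.enforce_limits(
    wall_time_in_s=60,
    mem_in_mb=1024,
    logger=logging.getLogger("pynisher-demo"),
)(train)

state_dict = safe_train(n_epochs=5)  # the function's result, or None if a limit was hit
if safe_train.exit_status == 0:
    # The subprocess result is a copy, so the caller applies it explicitly,
    # just as X['network'].load_state_dict(state_dict) does above.
    print(state_dict)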
Example #4
    def __init__(
        self,
        is_classification: bool = False,
        logger_port: Optional[int] = None,
    ) -> None:
        self.is_classification = is_classification
        self.logger_port = logger_port
        if self.logger_port is not None:
            self.logger: Union[
                logging.Logger,
                PicklableClientLogger] = get_named_client_logger(
                    name='Validation',
                    port=self.logger_port,
                )
        else:
            self.logger = logging.getLogger('Validation')

        self.feature_validator = TabularFeatureValidator(logger=self.logger)
        self.target_validator = TabularTargetValidator(
            is_classification=self.is_classification, logger=self.logger)
        self._is_fitted = False
Example #5
    def __init__(self,
                 task_type: str,
                 output_type: str,
                 optimize_metric: Optional[str] = None,
                 logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
                 random_state: Optional[np.random.RandomState] = None,
                 name: Optional[str] = None):

        self.model: Optional[Union[CatBoost, BaseEstimator]] = None

        self.name = name if name is not None else self.__class__.__name__
        self.logger_port = logger_port
        self.logger = get_named_client_logger(
            name=self.name,
            host='localhost',
            port=logger_port,
        )

        if random_state is None:
            self.random_state = check_random_state(1)
        else:
            self.random_state = check_random_state(random_state)
        self.config = self.get_config()

        self.all_nan: Optional[np.ndarray] = None
        self.num_classes: Optional[int] = None

        self.is_classification = STRING_TO_TASK_TYPES[task_type] not in REGRESSION_TASKS

        self.metric = get_metrics(
            dataset_properties={
                'task_type': task_type,
                'output_type': output_type
            },
            names=[optimize_metric] if optimize_metric is not None else None
        )[0]
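
The random_state handling above relies on sklearn's check_random_state, which accepts None, an integer seed, or an existing RandomState and always returns a RandomState instance. A quick illustration:

import numpy as np
from sklearn.utils import check_random_state

rs_from_seed = check_random_state(1)               # int seed -> RandomState seeded with 1
rs_passthrough = check_random_state(rs_from_seed)  # RandomState -> returned unchanged
rs_global = check_random_state(None)               # None -> numpy's global RandomState

assert isinstance(rs_from_seed, np.random.RandomState)
assert rs_passthrough is rs_from_seed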
Example #6
    def __init__(
        self,
        config_space: ConfigSpace.ConfigurationSpace,
        dataset_name: str,
        backend: Backend,
        total_walltime_limit: float,
        func_eval_time_limit_secs: float,
        memory_limit: Optional[int],
        metric: autoPyTorchMetric,
        watcher: StopWatch,
        n_jobs: int,
        dask_client: Optional[dask.distributed.Client],
        pipeline_config: Dict[str, Any],
        start_num_run: int = 1,
        seed: int = 1,
        resampling_strategy: Union[
            HoldoutValTypes,
            CrossValTypes] = HoldoutValTypes.holdout_validation,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        include: Optional[Dict[str, Any]] = None,
        exclude: Optional[Dict[str, Any]] = None,
        disable_file_output: List = [],
        smac_scenario_args: Optional[Dict[str, Any]] = None,
        get_smac_object_callback: Optional[Callable] = None,
        all_supported_metrics: bool = True,
        ensemble_callback: Optional[EnsembleBuilderManager] = None,
        logger_port: Optional[int] = None,
        search_space_updates: Optional[
            HyperparameterSearchSpaceUpdates] = None,
        portfolio_selection: Optional[str] = None,
        pynisher_context: str = 'spawn',
        min_budget: int = 5,
        max_budget: int = 50,
    ):
        """
        Interface to SMAC. This method calls the SMAC optimize method and allows
        passing a callback (ensemble_callback) to launch a task at the end of each
        optimize() call. The latter is needed due to the blocking nature of
        long-running tasks in Dask.

        Args:
            config_space (ConfigSpace.ConfigurationSpace):
                The configuration space of the whole process
            dataset_name (str):
                The name of the dataset, used to identify the current job
            backend (Backend):
                An interface with disk
            total_walltime_limit (float):
                The maximum allowed time for this job
            func_eval_time_limit_secs (float):
                How long each individual task is allowed to run
            memory_limit (Optional[int]):
                Maximum allowed CPU memory this task can use
            metric (autoPyTorchMetric):
                A scorer object to evaluate the performance of each job
            watcher (StopWatch):
                A stopwatch object to debug time consumption
            n_jobs (int):
                How many workers are allowed in each task
            dask_client (Optional[dask.distributed.Client]):
                A user-provided Dask scheduler. Otherwise SMAC will create its own.
            start_num_run (int):
                The ID index to start runs
            seed (int):
                To make the run deterministic
            resampling_strategy (Union[HoldoutValTypes, CrossValTypes]):
                What strategy to use for performance validation
            resampling_strategy_args (Optional[Dict[str, Any]]):
                Arguments to the resampling strategy -- like number of folds
            include (Optional[Dict[str, Any]] = None):
                Optional configuration space modifiers
            exclude (Optional[Dict[str, Any]] = None):
                Optional configuration space modifiers
            disable_file_output (List):
                Support to disable file output to disk -- to reduce space
            smac_scenario_args (Optional[Dict[str, Any]]):
                Additional arguments to the smac scenario
            get_smac_object_callback (Optional[Callable]):
                Allows creating a user-specified SMAC object
            pynisher_context (str):
                A string indicating the multiprocessing context to use
            ensemble_callback (Optional[EnsembleBuilderManager]):
                A callback used in this scenario to start ensemble building subtasks
            portfolio_selection (Optional[str]):
                This argument controls the initial configurations that
                AutoPyTorch uses to warm start SMAC for hyperparameter
                optimization. By default, no warm-starting happens.
                The user can provide a path to a json file containing
                configurations, similar to (autoPyTorch/configs/greedy_portfolio.json).
                Additionally, the keyword 'greedy' is supported,
                which would use the default portfolio from
                `AutoPyTorch Tabular <https://arxiv.org/abs/2006.13799>`_
            min_budget (int):
                Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to
                trade-off resources between running many pipelines at min_budget and
                running the top performing pipelines on max_budget.
                min_budget states the minimum resource allocation a pipeline should have
                so that we can compare and quickly discard bad performing models.
                For example, if the budget_type is epochs, and min_budget=5, then we will
                run every pipeline to a minimum of 5 epochs before performance comparison.
            max_budget (int):
                Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to
                trade-off resources between running many pipelines at min_budget and
                running the top performing pipelines on max_budget.
                max_budget states the maximum resource allocation a pipeline is going to
                be run. For example, if the budget_type is epochs, and max_budget=50,
                then the pipeline training will be terminated after 50 epochs.
        """
        super(AutoMLSMBO, self).__init__()
        # data related
        self.dataset_name = dataset_name
        self.datamanager: Optional[BaseDataset] = None
        self.metric = metric
        self.task: Optional[str] = None
        self.backend = backend
        self.all_supported_metrics = all_supported_metrics

        self.pipeline_config = pipeline_config
        # the configuration space
        self.config_space = config_space

        # the number of parallel workers/jobs
        self.n_jobs = n_jobs
        self.dask_client = dask_client

        # Evaluation
        self.resampling_strategy = resampling_strategy
        if resampling_strategy_args is None:
            resampling_strategy_args = DEFAULT_RESAMPLING_PARAMETERS[
                resampling_strategy]
        self.resampling_strategy_args = resampling_strategy_args

        # and a bunch of useful limits
        self.worst_possible_result = get_cost_of_crash(self.metric)
        self.total_walltime_limit = int(total_walltime_limit)
        self.func_eval_time_limit_secs = int(func_eval_time_limit_secs)
        self.memory_limit = memory_limit
        self.watcher = watcher
        self.seed = seed
        self.start_num_run = start_num_run
        self.include = include
        self.exclude = exclude
        self.disable_file_output = disable_file_output
        self.smac_scenario_args = smac_scenario_args
        self.get_smac_object_callback = get_smac_object_callback
        self.pynisher_context = pynisher_context
        self.min_budget = min_budget
        self.max_budget = max_budget

        self.ensemble_callback = ensemble_callback

        self.search_space_updates = search_space_updates

        if logger_port is None:
            self.logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
        else:
            self.logger_port = logger_port
        logger_name = '%s(%d):%s' % (self.__class__.__name__, self.seed,
                                     ":" + self.dataset_name)
        self.logger = get_named_client_logger(name=logger_name,
                                              port=self.logger_port)
        self.logger.info("initialised {}".format(self.__class__.__name__))

        self.initial_configurations: Optional[List[Configuration]] = None
        if portfolio_selection is not None:
            self.initial_configurations = read_return_initial_configurations(
                config_space=config_space,
                portfolio_selection=portfolio_selection)
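
The min_budget/max_budget pair described in the docstring only bounds the resource ladder; Hyperband fills in the rungs between them geometrically. The sketch below illustrates that spacing under the assumption of a halving factor eta=3 (a common default, not necessarily what SMAC uses here); it is an illustration of the idea, not SMAC's exact rung computation:

from typing import List


def hyperband_rungs(min_budget: float, max_budget: float, eta: float = 3.0) -> List[float]:
    # Geometric budget ladder: start at min_budget, multiply by eta,
    # and cap the last rung at max_budget.
    rungs = []
    budget = float(min_budget)
    while budget < max_budget:
        rungs.append(budget)
        budget *= eta
    rungs.append(float(max_budget))
    return rungs


# With the defaults above (min_budget=5, max_budget=50 epochs):
print(hyperband_rungs(5, 50))  # [5.0, 15.0, 45.0, 50.0]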
Example #7
    def __init__(
        self,
        backend: Backend,
        queue: Queue,
        metric: autoPyTorchMetric,
        budget: float,
        configuration: Union[int, str, Configuration],
        budget_type: Optional[str] = None,
        pipeline_config: Optional[Dict[str, Any]] = None,
        seed: int = 1,
        output_y_hat_optimization: bool = True,
        num_run: Optional[int] = None,
        include: Optional[Dict[str, Any]] = None,
        exclude: Optional[Dict[str, Any]] = None,
        disable_file_output: Union[bool, List[str]] = False,
        init_params: Optional[Dict[str, Any]] = None,
        logger_port: Optional[int] = None,
        all_supported_metrics: bool = True,
        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
    ) -> None:

        self.starttime = time.time()

        self.configuration = configuration
        self.backend: Backend = backend
        self.queue = queue

        self.datamanager: BaseDataset = self.backend.load_datamanager()

        assert self.datamanager.task_type is not None, \
            "Expected dataset {} to have task_type got None".format(self.datamanager.__class__.__name__)
        self.task_type = STRING_TO_TASK_TYPES[self.datamanager.task_type]
        self.output_type = STRING_TO_OUTPUT_TYPES[self.datamanager.output_type]
        self.issparse = self.datamanager.issparse

        self.include = include
        self.exclude = exclude
        self.search_space_updates = search_space_updates

        self.X_train, self.y_train = self.datamanager.train_tensors

        if self.datamanager.val_tensors is not None:
            self.X_valid, self.y_valid = self.datamanager.val_tensors
        else:
            self.X_valid, self.y_valid = None, None

        if self.datamanager.test_tensors is not None:
            self.X_test, self.y_test = self.datamanager.test_tensors
        else:
            self.X_test, self.y_test = None, None

        self.metric = metric

        self.seed = seed

        # Flag to save target for ensemble
        self.output_y_hat_optimization = output_y_hat_optimization

        if isinstance(disable_file_output, bool):
            self.disable_file_output: bool = disable_file_output
        elif isinstance(disable_file_output, List):
            self.disabled_file_outputs: List[str] = disable_file_output
        else:
            raise ValueError(
                'disable_file_output should be either a bool or a list')

        self.pipeline_class: Optional[Union[BaseEstimator,
                                            BasePipeline]] = None
        if self.task_type in REGRESSION_TASKS:
            if isinstance(self.configuration, int):
                self.pipeline_class = DummyRegressionPipeline
            elif isinstance(self.configuration, str):
                self.pipeline_class = MyTraditionalTabularRegressionPipeline
            elif isinstance(self.configuration, Configuration):
                self.pipeline_class = autoPyTorch.pipeline.tabular_regression.TabularRegressionPipeline
            else:
                raise ValueError('task {} not available'.format(
                    self.task_type))
            self.predict_function = self._predict_regression
        else:
            if isinstance(self.configuration, int):
                self.pipeline_class = DummyClassificationPipeline
            elif isinstance(self.configuration, str):
                if self.task_type in TABULAR_TASKS:
                    self.pipeline_class = MyTraditionalTabularClassificationPipeline
                else:
                    raise ValueError(
                        "Only tabular tasks are currently supported with traditional methods"
                    )
            elif isinstance(self.configuration, Configuration):
                if self.task_type in TABULAR_TASKS:
                    self.pipeline_class = autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline
                elif self.task_type in IMAGE_TASKS:
                    self.pipeline_class = autoPyTorch.pipeline.image_classification.ImageClassificationPipeline
                else:
                    raise ValueError('task {} not available'.format(
                        self.task_type))
            self.predict_function = self._predict_proba
        self.dataset_properties = self.datamanager.get_dataset_properties(
            get_dataset_requirements(
                info=self.datamanager.get_required_dataset_info(),
                include=self.include,
                exclude=self.exclude,
                search_space_updates=self.search_space_updates))

        self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
        metrics_dict: Optional[Dict[str, List[str]]] = None
        if all_supported_metrics:
            self.additional_metrics = get_metrics(
                dataset_properties=self.dataset_properties,
                all_supported_metrics=all_supported_metrics)
            # Update fit dictionary with metrics passed to the evaluator
            metrics_dict = {'additional_metrics': []}
            metrics_dict['additional_metrics'].append(self.metric.name)
            for metric in self.additional_metrics:
                metrics_dict['additional_metrics'].append(metric.name)

        self._init_params = init_params

        assert self.pipeline_class is not None, "Could not infer pipeline class"
        pipeline_config = pipeline_config if pipeline_config is not None \
            else self.pipeline_class.get_default_pipeline_options()
        self.budget_type = pipeline_config[
            'budget_type'] if budget_type is None else budget_type
        self.budget = pipeline_config[
            self.budget_type] if budget == 0 else budget

        self.num_run = 0 if num_run is None else num_run

        logger_name = '%s(%d)' % (self.__class__.__name__.split('.')[-1],
                                  self.seed)
        if logger_port is None:
            logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
        self.logger = get_named_client_logger(
            name=logger_name,
            port=logger_port,
        )

        self._init_fit_dictionary(logger_port=logger_port,
                                  pipeline_config=pipeline_config,
                                  metrics_dict=metrics_dict)
        self.Y_optimization: Optional[np.ndarray] = None
        self.Y_actual_train: Optional[np.ndarray] = None
        self.pipelines: Optional[List[BaseEstimator]] = None
        self.pipeline: Optional[BaseEstimator] = None
        self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(
            dict_repr(self.fit_dictionary)))
        self.logger.debug("Search space updates :{}".format(
            self.search_space_updates))
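
The budget handling near the end of this __init__ falls back twice: budget_type defaults to the value stored in pipeline_config, and a budget of 0 means 'use the pipeline default for that budget_type'. The same logic, isolated for clarity; the pipeline_config values are hypothetical stand-ins for what get_default_pipeline_options() might return:

from typing import Any, Dict, Optional, Tuple


def resolve_budget(pipeline_config: Dict[str, Any],
                   budget_type: Optional[str] = None,
                   budget: float = 0.0) -> Tuple[str, float]:
    # Mirrors the fallback logic above.
    budget_type = pipeline_config['budget_type'] if budget_type is None else budget_type
    budget = pipeline_config[budget_type] if budget == 0 else budget
    return budget_type, budget


print(resolve_budget({'budget_type': 'epochs', 'epochs': 50}))               # ('epochs', 50)
print(resolve_budget({'budget_type': 'epochs', 'epochs': 50}, budget=10.0))  # ('epochs', 10.0)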
Example #8
    def run(
        self,
        config: Configuration,
        instance: Optional[str] = None,
        cutoff: Optional[float] = None,
        seed: int = 12345,
        budget: float = 0.0,
        instance_specific: Optional[str] = None,
    ) -> Tuple[StatusType, float, float, Dict[str, Any]]:

        context = multiprocessing.get_context(self.pynisher_context)
        preload_modules(context)
        queue: multiprocessing.queues.Queue = context.Queue()

        if not (instance_specific is None or instance_specific == '0'):
            raise ValueError(instance_specific)
        init_params = {'instance': instance}
        if self.init_params is not None:
            init_params.update(self.init_params)

        if self.logger_port is None:
            logger: Union[logging.Logger,
                          PicklableClientLogger] = logging.getLogger(
                              "pynisher")
        else:
            logger = get_named_client_logger(
                name="pynisher",
                port=self.logger_port,
            )

        pynisher_arguments = dict(
            logger=logger,
            # Pynisher expects seconds as a time indicator
            wall_time_in_s=int(cutoff) if cutoff is not None else None,
            mem_in_mb=self.memory_limit,
            capture_output=True,
            context=context,
        )

        if isinstance(config, (int, str)):
            num_run = self.initial_num_run
        else:
            num_run = config.config_id + self.initial_num_run

        self.logger.debug("Search space updates for {}: {}".format(
            num_run, self.search_space_updates))
        obj_kwargs = dict(
            queue=queue,
            config=config,
            backend=self.backend,
            metric=self.metric,
            seed=self.seed,
            num_run=num_run,
            output_y_hat_optimization=self.output_y_hat_optimization,
            include=self.include,
            exclude=self.exclude,
            disable_file_output=self.disable_file_output,
            instance=instance,
            init_params=init_params,
            budget=budget,
            budget_type=self.budget_type,
            pipeline_config=self.pipeline_config,
            logger_port=self.logger_port,
            all_supported_metrics=self.all_supported_metrics,
            search_space_updates=self.search_space_updates)

        info: Optional[List[RunValue]]
        additional_run_info: Dict[str, Any]
        try:
            obj = pynisher.enforce_limits(**pynisher_arguments)(self.ta)
            obj(**obj_kwargs)
        except Exception as e:
            exception_traceback = traceback.format_exc()
            error_message = repr(e)
            additional_run_info = {
                'traceback': exception_traceback,
                'error': error_message
            }
            return StatusType.CRASHED, self.cost_for_crash, 0.0, additional_run_info

        if obj.exit_status in (pynisher.TimeoutException,
                               pynisher.MemorylimitException):
            # Even if pynisher thinks that a timeout or memout occurred,
            # the target algorithm may still have written something into the
            # queue -- in that case we treat it as a successful run
            try:
                info = read_queue(queue)  # type: ignore
                result = info[-1]['loss']  # type: ignore
                status = info[-1]['status']  # type: ignore
                additional_run_info = info[-1][
                    'additional_run_info']  # type: ignore

                if obj.stdout:
                    additional_run_info['subprocess_stdout'] = obj.stdout
                if obj.stderr:
                    additional_run_info['subprocess_stderr'] = obj.stderr

                if obj.exit_status is pynisher.TimeoutException:
                    additional_run_info[
                        'info'] = 'Run stopped because of timeout.'
                elif obj.exit_status is pynisher.MemorylimitException:
                    additional_run_info[
                        'info'] = 'Run stopped because of memout.'

                if status in [StatusType.SUCCESS, StatusType.DONOTADVANCE]:
                    cost = result
                else:
                    cost = self.worst_possible_result

            except Empty:
                info = None
                if obj.exit_status is pynisher.TimeoutException:
                    status = StatusType.TIMEOUT
                    additional_run_info = {'error': 'Timeout'}
                elif obj.exit_status is pynisher.MemorylimitException:
                    status = StatusType.MEMOUT
                    additional_run_info = {
                        'error':
                        'Memout (used more than {} MB).'.format(
                            self.memory_limit)
                    }
                else:
                    raise ValueError(obj.exit_status)
                cost = self.worst_possible_result

        elif obj.exit_status is TAEAbortException:
            info = None
            status = StatusType.ABORT
            cost = self.worst_possible_result
            additional_run_info = {
                'error': 'Your configuration of '
                'autoPyTorch does not work!',
                'exit_status': _encode_exit_status(obj.exit_status),
                'subprocess_stdout': obj.stdout,
                'subprocess_stderr': obj.stderr,
            }

        else:
            try:
                info = read_queue(queue)  # type: ignore
                result = info[-1]['loss']  # type: ignore
                status = info[-1]['status']  # type: ignore
                additional_run_info = info[-1][
                    'additional_run_info']  # type: ignore

                if obj.exit_status == 0:
                    cost = result
                else:
                    status = StatusType.CRASHED
                    cost = self.worst_possible_result
                    additional_run_info['info'] = 'Run treated as crashed ' \
                                                  'because the pynisher exit ' \
                                                  'status %s is unknown.' % \
                                                  str(obj.exit_status)
                    additional_run_info['exit_status'] = _encode_exit_status(
                        obj.exit_status)
                    additional_run_info['subprocess_stdout'] = obj.stdout
                    additional_run_info['subprocess_stderr'] = obj.stderr
            except Empty:
                info = None
                additional_run_info = {
                    'error': 'Result queue is empty',
                    'exit_status': _encode_exit_status(obj.exit_status),
                    'subprocess_stdout': obj.stdout,
                    'subprocess_stderr': obj.stderr,
                    'exitcode': obj.exitcode
                }
                status = StatusType.CRASHED
                cost = self.worst_possible_result

        if ((self.budget_type is None or budget == 0)
                and status == StatusType.DONOTADVANCE):
            status = StatusType.SUCCESS

        if not isinstance(additional_run_info, dict):
            additional_run_info = {'message': additional_run_info}

        if (info is not None and self.resampling_strategy
                in ['holdout-iterative-fit', 'cv-iterative-fit']
                and status != StatusType.CRASHED):
            learning_curve = extract_learning_curve(info)
            learning_curve_runtime = extract_learning_curve(info, 'duration')
            if len(learning_curve) > 1:
                additional_run_info['learning_curve'] = learning_curve
                additional_run_info[
                    'learning_curve_runtime'] = learning_curve_runtime

            train_learning_curve = extract_learning_curve(info, 'train_loss')
            if len(train_learning_curve) > 1:
                additional_run_info[
                    'train_learning_curve'] = train_learning_curve
                additional_run_info[
                    'learning_curve_runtime'] = learning_curve_runtime

            if self._get_validation_loss:
                validation_learning_curve = extract_learning_curve(
                    info, 'validation_loss')
                if len(validation_learning_curve) > 1:
                    additional_run_info['validation_learning_curve'] = \
                        validation_learning_curve
                    additional_run_info[
                        'learning_curve_runtime'] = learning_curve_runtime

            if self._get_test_loss:
                test_learning_curve = extract_learning_curve(info, 'test_loss')
                if len(test_learning_curve) > 1:
                    additional_run_info[
                        'test_learning_curve'] = test_learning_curve
                    additional_run_info[
                        'learning_curve_runtime'] = learning_curve_runtime

        if isinstance(config, int):
            origin = 'DUMMY'
        elif isinstance(config, str):
            origin = 'traditional'
        else:
            origin = getattr(config, 'origin', 'UNKNOWN')
        additional_run_info['configuration_origin'] = origin

        runtime = float(obj.wall_clock_time)

        empty_queue(queue)
        self.logger.debug("Finish function evaluation {}.\n"
                          "Status: {}, Cost: {}, Runtime: {},\n"
                          "Additional information:\n{}".format(
                              str(num_run), status, cost, runtime,
                              dict_repr(additional_run_info)))
        return status, cost, runtime, additional_run_info
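
run() never receives a return value from the evaluator directly: the target function is expected to put dictionaries carrying 'loss', 'status' and 'additional_run_info' on the multiprocessing queue, and read_queue (AutoPyTorch-internal) then drains the queue so the last entry wins. Below is a minimal standard-library stand-in for that hand-off; drain_queue is an assumption about read_queue's behaviour, not its actual implementation:

import multiprocessing
from queue import Empty


def target_function(queue) -> None:
    # Stand-in for the evaluator: report an intermediate and a final result.
    queue.put({'loss': 0.42, 'status': 'SUCCESS', 'additional_run_info': {}})
    queue.put({'loss': 0.31, 'status': 'SUCCESS', 'additional_run_info': {}})


def drain_queue(queue) -> list:
    # Assumed behaviour of read_queue: collect everything currently queued.
    items = []
    while True:
        try:
            items.append(queue.get(timeout=1))
        except Empty:
            break
    return items


if __name__ == '__main__':
    context = multiprocessing.get_context('spawn')
    queue = context.Queue()
    proc = context.Process(target=target_function, args=(queue,))
    proc.start()
    proc.join()
    info = drain_queue(queue)
    print(info[-1]['loss'])  # run() uses the last reported result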
Example #9
    def __init__(
        self,
        backend: Backend,
        seed: int,
        metric: autoPyTorchMetric,
        cost_for_crash: float,
        abort_on_first_run_crash: bool,
        pynisher_context: str,
        pipeline_config: Optional[Dict[str, Any]] = None,
        initial_num_run: int = 1,
        stats: Optional[Stats] = None,
        run_obj: str = 'quality',
        par_factor: int = 1,
        output_y_hat_optimization: bool = True,
        include: Optional[Dict[str, Any]] = None,
        exclude: Optional[Dict[str, Any]] = None,
        memory_limit: Optional[int] = None,
        disable_file_output: bool = False,
        init_params: Optional[Dict[str, Any]] = None,
        budget_type: Optional[str] = None,
        ta: Optional[Callable] = None,
        logger_port: Optional[int] = None,
        all_supported_metrics: bool = True,
        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
    ):

        eval_function = autoPyTorch.evaluation.train_evaluator.eval_function

        self.worst_possible_result = cost_for_crash

        eval_function = functools.partial(
            fit_predict_try_except_decorator,
            ta=eval_function,
            cost_for_crash=self.worst_possible_result,
        )

        super().__init__(
            ta=ta if ta is not None else eval_function,
            stats=stats,
            run_obj=run_obj,
            par_factor=par_factor,
            cost_for_crash=self.worst_possible_result,
            abort_on_first_run_crash=abort_on_first_run_crash,
        )

        self.backend = backend
        self.pynisher_context = pynisher_context
        self.seed = seed
        self.initial_num_run = initial_num_run
        self.metric = metric
        self.output_y_hat_optimization = output_y_hat_optimization
        self.include = include
        self.exclude = exclude
        self.disable_file_output = disable_file_output
        self.init_params = init_params

        self.budget_type = pipeline_config[
            'budget_type'] if pipeline_config is not None else budget_type

        self.pipeline_config: Dict[str, Union[int, str, float]] = dict()
        if pipeline_config is None:
            pipeline_config = replace_string_bool_to_bool(
                json.load(
                    open(
                        os.path.join(
                            os.path.dirname(__file__),
                            '../configs/default_pipeline_options.json'))))
        self.pipeline_config.update(pipeline_config)

        self.logger_port = logger_port
        if self.logger_port is None:
            self.logger: Union[logging.Logger,
                               PicklableClientLogger] = logging.getLogger(
                                   "TAE")
        else:
            self.logger = get_named_client_logger(
                name="TAE",
                port=self.logger_port,
            )
        self.all_supported_metrics = all_supported_metrics

        if memory_limit is not None:
            memory_limit = int(math.ceil(memory_limit))
        self.memory_limit = memory_limit

        dm = self.backend.load_datamanager()
        if dm.val_tensors is not None:
            self._get_validation_loss = True
        else:
            self._get_validation_loss = False
        if dm.test_tensors is not None:
            self._get_test_loss = True
        else:
            self._get_test_loss = False

        self.resampling_strategy = dm.resampling_strategy
        self.resampling_strategy_args = dm.resampling_strategy_args

        self.search_space_updates = search_space_updates
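
The ta callable handed to SMAC is not the raw evaluation function but a functools.partial around fit_predict_try_except_decorator, which binds the crash cost up front. That decorator is AutoPyTorch-internal; the following is a rough, hypothetical sketch of the idea -- run the real function and, on any exception, report the crash cost through the queue instead of propagating -- with all names and the queue payload being illustrative:

import functools
import traceback
from typing import Any, Callable


def try_except_decorator(queue: Any, ta: Callable, cost_for_crash: float, **kwargs: Any) -> None:
    # Hypothetical stand-in: never let an exception escape into SMAC;
    # report the crash cost through the queue instead.
    try:
        ta(queue=queue, **kwargs)
    except Exception:
        queue.put({
            'loss': cost_for_crash,
            'status': 'CRASHED',
            'additional_run_info': {'traceback': traceback.format_exc()},
        })


class _ListQueue(list):
    # Tiny in-process stand-in for a multiprocessing queue.
    put = list.append


queue = _ListQueue()
guarded = functools.partial(try_except_decorator,
                            ta=lambda **kw: 1 / 0,  # an evaluator that always crashes
                            cost_for_crash=2.0)
guarded(queue=queue)
print(queue[-1]['status'])  # 'CRASHED'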