示例#1
0
    def __init__(
            self,
            epochs: int,
            scoring_callback: ScoringCallback,
            validation_splitter: 'BaseValidationSplitter',
            callbacks: List[BaseCallback] = None,
            print_func: Callable = None,
            hyperparams_repository: HyperparamsRepository = None
    ):
        """
        Store the training configuration.

        :param epochs: number of epochs
        :param scoring_callback: callback always placed first in ``self.callbacks``
        :param validation_splitter: splitter, stored as ``self.validation_split_function``
        :param callbacks: additional callbacks, run after ``scoring_callback``; ``None`` means no extra callbacks
        :param print_func: printing function; ``None`` falls back to the builtin ``print``
        :param hyperparams_repository: hyperparams repository; ``None`` creates a new ``InMemoryHyperparamsRepository``
        """
        self.epochs: int = epochs
        self.validation_split_function = validation_splitter

        # The scoring callback leads; user callbacks follow. A new list is built,
        # so the caller's list is never mutated.
        extra_callbacks = [] if callbacks is None else callbacks
        self.callbacks: CallbackList = CallbackList([scoring_callback] + extra_callbacks)

        if hyperparams_repository is None:
            hyperparams_repository = InMemoryHyperparamsRepository()
        self.hyperparams_repository: HyperparamsRepository = hyperparams_repository

        self.print_func = print if print_func is None else print_func
示例#2
0
    def __init__(self,
                 epochs,
                 metrics=None,
                 callbacks=None,
                 print_metrics=True,
                 print_func=None):
        """
        Store the training configuration.

        :param epochs: number of epochs
        :param metrics: metrics mapping (presumably metric name -> metric function; confirm
                        against callers); ``None`` means an empty dict. Passed to
                        ``self._initialize_metrics``.
        :param callbacks: callbacks wrapped in a ``CallbackList``; ``None`` means no callbacks
        :param print_metrics: flag stored as ``self.print_metrics``
        :param print_func: printing function; ``None`` falls back to the builtin ``print``
        """
        self.epochs = epochs

        if metrics is None:
            metrics = {}
        self.metrics = metrics
        self._initialize_metrics(metrics)

        # Normalise a missing argument to an empty list, as done for ``metrics``,
        # so ``CallbackList`` is never constructed around ``None``.
        if callbacks is None:
            callbacks = []
        self.callbacks = CallbackList(callbacks)

        if print_func is None:
            print_func = print

        self.print_func = print_func
        self.print_metrics = print_metrics
示例#3
0
    def __init__(
            self,
            epochs: int,
            scoring_callback: ScoringCallback,
            validation_splitter: 'BaseValidationSplitter',
            callbacks: List[BaseCallback] = None,
            print_func: Callable = None
    ):
        """
        Store the training configuration.

        :param epochs: number of epochs
        :param scoring_callback: callback always placed first in ``self.callbacks``
        :param validation_splitter: splitter, stored as ``self.validation_split_function``
        :param callbacks: additional callbacks, run after ``scoring_callback``; ``None`` means no extra callbacks
        :param print_func: printing function; ``None`` falls back to the builtin ``print``
        """
        self.epochs: int = epochs
        self.validation_split_function = validation_splitter

        # Build a fresh list so the caller's ``callbacks`` list is left untouched.
        extra_callbacks = [] if callbacks is None else callbacks
        self.callbacks: CallbackList = CallbackList([scoring_callback] + extra_callbacks)

        self.print_func = print if print_func is None else print_func
示例#4
0
class Trainer:
    """
    Epoch loop that fits a trial split and hands each epoch's predictions to callbacks.

    Example usage :

    .. code-block:: python

        trainer = Trainer(
            callbacks=[],
            epochs=10,
            print_func=print
        )

        trial_split = trainer.fit_trial_split(
            trial_split=trial_split,
            train_data_container=training_data_container,
            validation_data_container=validation_data_container,
            context=context
        )

        pipeline = trainer.refit(p, data_container, context)


    .. seealso::
        :class:`AutoML`,
        :class:`~neuraxle.metaopt.trial.Trial`,
        :class:`InMemoryHyperparamsRepository`,
        :class:`HyperparamsJSONRepository`,
        :class:`BaseHyperparameterSelectionStrategy`,
        :class:`RandomSearchHyperparameterSelectionStrategy`,
        :class:`~neuraxle.hyperparams.space.HyperparameterSamples`
    """
    def __init__(self,
                 epochs,
                 metrics=None,
                 callbacks=None,
                 print_metrics=True,
                 print_func=None):
        """
        Store the training configuration.

        :param epochs: number of epochs run by ``fit_trial_split`` and ``refit``
        :param metrics: metrics mapping (presumably metric name -> metric function; confirm
                        against callers); ``None`` means an empty dict
        :param callbacks: callbacks wrapped in a ``CallbackList``; ``None`` means no callbacks
        :param print_metrics: flag stored as ``self.print_metrics``
        :param print_func: printing function; ``None`` falls back to the builtin ``print``
        """
        self.epochs = epochs
        if metrics is None:
            metrics = {}
        self.metrics = metrics
        self._initialize_metrics(metrics)

        # Normalise a missing argument to an empty list, as done for ``metrics``,
        # so ``CallbackList`` is never constructed around ``None``.
        if callbacks is None:
            callbacks = []
        self.callbacks = CallbackList(callbacks)

        if print_func is None:
            print_func = print

        self.print_func = print_func
        self.print_metrics = print_metrics

    def fit_trial_split(self, trial_split: TrialSplit,
                        train_data_container: DataContainer,
                        validation_data_container: DataContainer,
                        context: ExecutionContext) -> TrialSplit:
        """
        Train the trial split's pipeline for up to ``self.epochs`` epochs.

        Each epoch does the following:

        1. Fit on the training data.
        2. Predict on both the training and validation data.
        3. Pass everything to the callbacks.

        The loop stops early as soon as ``self.callbacks.call`` returns a truthy value.

        :param trial_split: trial split to execute
        :param train_data_container: train data container
        :param validation_data_container: validation data container
        :param context: execution context

        :return: the trial split returned by the last ``fit_trial_split`` call
        """
        for i in range(self.epochs):
            self.print_func('\nepoch {}/{}'.format(i + 1, self.epochs))
            trial_split = trial_split.fit_trial_split(train_data_container,
                                                      context)
            y_pred_train = trial_split.predict_with_pipeline(
                train_data_container, context)
            y_pred_val = trial_split.predict_with_pipeline(
                validation_data_container, context)

            # This loop never marks the trial as finished. Callbacks are always
            # told ``is_finished_and_fitted=False``.
            if self.callbacks.call(trial=trial_split,
                                   epoch_number=i,
                                   total_epochs=self.epochs,
                                   input_train=train_data_container,
                                   pred_train=y_pred_train,
                                   input_val=validation_data_container,
                                   pred_val=y_pred_val,
                                   is_finished_and_fitted=False):
                break

        return trial_split

    def refit(self, p: BaseStep, data_container: DataContainer,
              context: ExecutionContext) -> BaseStep:
        """
        Refit the pipeline on the whole dataset, without any validation technique.

        ``p.handle_fit`` is called once per epoch. Each call's result replaces ``p``.

        :param p: pipeline to refit
        :param data_container: data container
        :param context: execution context

        :return: fitted pipeline
        """
        for _ in range(self.epochs):
            p = p.handle_fit(data_container, context)

        return p

    def _initialize_metrics(self, metrics):
        """
        Reset the per-metric result lists for train and validation.

        :param metrics: metrics mapping; each of its keys gets an empty result list

        :return: None
        """
        self.metrics_results_train = {name: [] for name in metrics}
        self.metrics_results_validation = {name: [] for name in metrics}

    def get_main_metric_name(self) -> str:
        """
        Get the main metric name.

        This is the ``name`` attribute of the first callback. It assumes the
        caller put the main scoring callback first (TODO confirm). It will fail
        if no callbacks were given.

        :return: main metric name
        """
        return self.callbacks[0].name