Exemplo n.º 1
0
    def save_checkpoint(self, data_container: DataContainer,
                        context: ExecutionContext) -> DataContainer:
        """
        Save a copy of the execution context when this step applies to the
        context's current execution mode.

        The data container itself is never modified here; it is handed back
        untouched whether or not a save happened.

        :param data_container: data container, returned as is
        :param context: execution context whose copy gets saved
        :return: the same data container that was passed in
        """
        current_mode = context.get_execution_mode()
        if not self.is_for_execution_mode(current_mode):
            # Not the execution mode this step handles: nothing to save.
            return data_container

        # TODO: save the context by execution mode AND data container ids / summary
        # The save is called on a copy, not on the caller's context object.
        context_snapshot = context.copy()
        context_snapshot.save()
        return data_container
Exemplo n.º 2
0
    def _load_checkpoint(
            self, data_container: DataContainer,
            context: ExecutionContext) -> Tuple[NamedTupleList, DataContainer]:
        """
        Attempt to resume from a previously saved pipeline.

        The starting step index and its data container are first obtained
        from ``_get_starting_step_info``. A pipeline is then loaded using a
        popped copy of the context. If the steps before the starting index do
        not match the loaded pipeline, nothing is resumed: all the steps and
        the original data container are returned. Otherwise the loaded
        pipeline is assigned into ``self`` and, when the starting step is a
        checkpoint (directly, or wrapped by a meta step), that step's
        ``resume`` provides the data container to continue with.

        :param data_container: the data container to resume
        :param context: the execution context to resume
        :return: tuple(steps left to do, data container to continue with)
        """
        resume_index, resume_data_container = self._get_starting_step_info(
            data_container, context)

        # Load with a copy of the context after calling pop() on it, so the
        # caller's context is left as it was.
        # NOTE(review): presumably pop() moves the context up to the parent
        # level where the pipeline was saved - confirm.
        parent_context = context.copy()
        parent_context.pop()
        saved_pipeline = self.load(parent_context)

        steps_match = self.are_steps_before_index_the_same(
            saved_pipeline, resume_index)
        if not steps_match:
            # Steps before the index differ: run everything from the
            # beginning with the data container that was passed in.
            return self.steps_as_tuple, data_container

        self._assign_loaded_pipeline_into_self(saved_pipeline)

        first_step = self[resume_index]
        # A step counts as a checkpoint either directly or when it is a meta
        # step whose wrapped step is a checkpoint. `wrapped` is only read
        # when the first test fails and the step is a MetaStep.
        is_checkpoint = isinstance(first_step, Checkpoint)
        if not is_checkpoint and isinstance(first_step, MetaStep):
            is_checkpoint = isinstance(first_step.wrapped, Checkpoint)

        if is_checkpoint:
            resume_data_container = first_step.resume(
                resume_data_container, context)

        remaining_steps = self[resume_index:]
        return remaining_steps, resume_data_container
Exemplo n.º 3
0
    def fit_trial_split(self, trial_split: TrialSplit,
                        train_data_container: DataContainer,
                        validation_data_container: DataContainer,
                        context: ExecutionContext) -> TrialSplit:
        """
        Fit the trial split for up to ``self.epochs`` epochs.

        Each epoch fits on a copy of the training data, predicts on copies of
        both the training and the validation data, and then passes the
        results to ``self.callbacks``. A truthy return value from the
        callbacks ends the loop immediately; in that case
        ``save_parent_trial`` is not called for that last epoch.

        :param trial_split: trial split to execute
        :param train_data_container: train data container
        :param validation_data_container: validation data container
        :param context: execution context
        :return: executed trial split
        """

        def context_for(phase):
            # A fresh copy per call, so the caller's context is not the
            # object the phase is set on.
            return context.copy().set_execution_phase(phase)

        for epoch_index in range(self.epochs):
            context.logger.info(
                'epoch {}/{}'.format(epoch_index + 1, self.epochs))

            trial_split = trial_split.fit_trial_split(
                train_data_container.copy(),
                context_for(ExecutionPhase.TRAIN))

            # NOTE(review): the predictions on the training data also run
            # under the VALIDATION phase - confirm this is intended.
            train_predictions = trial_split.predict_with_pipeline(
                train_data_container.copy(),
                context_for(ExecutionPhase.VALIDATION))
            validation_predictions = trial_split.predict_with_pipeline(
                validation_data_container.copy(),
                context_for(ExecutionPhase.VALIDATION))

            # The callbacks receive the original (uncopied) input containers
            # together with the predictions; epoch_number is zero-based.
            stop_requested = self.callbacks.call(
                trial_split=trial_split,
                epoch_number=epoch_index,
                total_epochs=self.epochs,
                input_train=train_data_container,
                pred_train=train_predictions,
                input_val=validation_data_container,
                pred_val=validation_predictions,
                context=context_for(ExecutionPhase.VALIDATION),
                is_finished_and_fitted=False,
            )
            if stop_requested:
                break

            # Saves the metrics
            trial_split.save_parent_trial()

        return trial_split