Exemplo n.º 1
0
    def should_resume(self, data_container: DataContainer,
                      context: ExecutionContext) -> bool:
        """
        Tell whether the whole data container has been checkpointed.

        The summary checkpoint is checked first. When it exists, every id
        read from the summary must have both a data input checkpoint and an
        expected output checkpoint.

        :param data_container: data container to look up checkpoints for
        :type data_container: neuraxle.data_container.DataContainer
        :param context: execution context providing the checkpoint path
        :type context: ExecutionContext
        :return: True when the summary and all per-id checkpoints exist, False otherwise
        :rtype: bool
        """
        summary_found = self.summary_checkpointer.checkpoint_exists(
            context.get_path(), data_container)
        if not summary_found:
            return False

        saved_ids = self.summary_checkpointer.read_summary(
            checkpoint_path=context.get_path(),
            data_container=data_container)

        # all() short-circuits: the scan stops at the first id that is
        # missing either its data input or its expected output checkpoint.
        return all(
            self.data_input_checkpointer.checkpoint_exists(
                checkpoint_path=self._get_data_input_checkpoint_path(context),
                current_id=saved_id)
            and self.expected_output_checkpointer.checkpoint_exists(
                checkpoint_path=self._get_expected_output_checkpoint_path(context),
                current_id=saved_id)
            for saved_id in saved_ids
        )
Exemplo n.º 2
0
    def read_checkpoint(self, data_container: DataContainer,
                        context: ExecutionContext) -> DataContainer:
        """
        Rebuild a data container from its checkpoints.

        The ids are read from the summary checkpoint. For each id, the data
        input is read with :py:attr:`~data_input_checkpointer` and the
        expected output with :py:attr:`~expected_output_checkpointer`, then
        both are appended to a new, initially empty list data container.

        :param data_container: data container to read checkpoint for
        :type data_container: neuraxle.data_container.DataContainer
        :param context: execution context to read checkpoint from
        :type context: ExecutionContext
        :return: data container filled from the checkpoints
        :rtype: neuraxle.data_container.DataContainer
        """
        restored = ListDataContainer.empty(
            original_data_container=data_container)

        summary_ids = self.summary_checkpointer.read_summary(
            checkpoint_path=context.get_path(),
            data_container=data_container)

        for saved_id in summary_ids:
            # Data input is read before the expected output for each id.
            restored.append(
                saved_id,
                self.data_input_checkpointer.read_checkpoint(
                    checkpoint_path=self._get_data_input_checkpoint_path(context),
                    current_id=saved_id),
                self.expected_output_checkpointer.read_checkpoint(
                    checkpoint_path=self._get_expected_output_checkpoint_path(context),
                    current_id=saved_id),
            )

        return restored
Exemplo n.º 3
0
    def save_checkpoint(self, data_container: DataContainer,
                        context: ExecutionContext) -> DataContainer:
        """
        Checkpoint a data container, one entry at a time.

        Nothing is saved when this checkpointer is not meant for the
        context's execution mode. Otherwise the summary is saved first, then
        each data input with :py:attr:`~data_input_checkpointer` and each
        expected output with :py:attr:`~expected_output_checkpointer`.

        :param data_container: data container to checkpoint
        :type data_container: neuraxle.data_container.DataContainer
        :param context: execution context to checkpoint from
        :type context: ExecutionContext
        :return: the data container that was passed in
        :rtype: neuraxle.data_container.DataContainer
        """
        execution_mode = context.get_execution_mode()

        if self.is_for_execution_mode(execution_mode):
            context.mkdir()

            self.summary_checkpointer.save_summary(
                checkpoint_path=context.get_path(),
                data_container=data_container)

            for item_id, item_input, item_expected in data_container:
                self.data_input_checkpointer.save_checkpoint(
                    checkpoint_path=self._get_data_input_checkpoint_path(context),
                    current_id=item_id,
                    data=item_input)

                self.expected_output_checkpointer.save_checkpoint(
                    checkpoint_path=self._get_expected_output_checkpoint_path(context),
                    current_id=item_id,
                    data=item_expected)

        return data_container
 def _get_saved_model_path(self, context: ExecutionContext, step: BaseStep):
     """
     Build the saved model path from the execution context path and the step name.

     :param context: execution context whose path is used as the directory
     :type context: ExecutionContext
     :param step: step whose name is used for the file name
     :type step: BaseStep
     :return: the context path joined with ``<step name>.ckpt``
     """
     checkpoint_dir = context.get_path()
     file_name = "{0}.ckpt".format(step.get_name())
     return os.path.join(checkpoint_dir, file_name)
Exemplo n.º 5
0
    def _fit_transform_data_container(
            self,
            data_container: DataContainer,
            context: ExecutionContext
    ) -> ('BaseStep', DataContainer):
        """
        Fit the wrapped step, then transform the data container.

        Calls ``create_checkpoint_path`` with the context path and
        ``flush_cache`` before fitting. The wrapped step is fitted on the
        container's data inputs and expected outputs, and the container's
        data inputs are then replaced by the result of
        ``_transform_with_cache``.

        :param data_container: the data container to fit on and transform
        :type data_container: neuraxle.data_container.DataContainer
        :param context: execution context
        :type context: ExecutionContext

        :return: tuple(fitted pipeline, data_container)
        """
        checkpoint_dir = context.get_path()
        self.create_checkpoint_path(checkpoint_dir)
        self.flush_cache()

        # Keep whatever fit() returns as the new wrapped step.
        fitted_step = self.wrapped.fit(
            data_container.data_inputs,
            data_container.expected_outputs)
        self.wrapped = fitted_step

        data_container.set_data_inputs(
            self._transform_with_cache(data_container))

        return self, data_container
Exemplo n.º 6
0
    def handle_transform(self, data_container: DataContainer,
                         context: ExecutionContext) -> DataContainer:
        """
        Transform the data container and refresh its current ids.

        Calls ``create_checkpoint_path`` with the context path, replaces the
        container's data inputs with the result of ``_transform_with_cache``,
        then sets the current ids to the result of ``self.hash``.

        :param data_container: the data container to transform
        :type data_container: neuraxle.data_container.DataContainer
        :param context: execution context
        :type context: ExecutionContext

        :return: transformed data container
        """
        checkpoint_dir = context.get_path()
        self.create_checkpoint_path(checkpoint_dir)

        data_container.set_data_inputs(
            self._transform_with_cache(data_container))

        # The ids are hashed only after the data inputs have been replaced.
        data_container.set_current_ids(self.hash(data_container))

        return data_container