Example #1
    def __init__(self,
                 checkpoint_params,
                 dataset,
                 validation_dataset=None,
                 txt_preproc=None,
                 txt_postproc=None,
                 data_preproc=None,
                 data_augmenter=None,
                 n_augmentations=0,
                 weights=None,
                 codec=None,
                 codec_whitelist=None):
        self.checkpoint_params = checkpoint_params
        self.dataset = dataset
        self.validation_dataset = validation_dataset
        self.data_augmenter = data_augmenter
        self.n_augmentations = n_augmentations
        self.txt_preproc = txt_preproc if txt_preproc else text_processor_from_proto(
            checkpoint_params.model.text_preprocessor, "pre")
        self.txt_postproc = txt_postproc if txt_postproc else text_processor_from_proto(
            checkpoint_params.model.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
            checkpoint_params.model.data_preprocessor)
        self.weights = checkpoint_path(weights) if weights else None
        self.codec = codec
        # None as default avoids a shared mutable default argument.
        self.codec_whitelist = codec_whitelist if codec_whitelist is not None else []
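
A minimal usage sketch for this constructor; `CheckpointParams` is the proto type documented in the later examples, while the dataset object and all concrete values are illustrative assumptions:

    # Hedged sketch: the dataset construction is an assumption, not part of the snippet above.
    checkpoint_params = CheckpointParams()         # proto object holding all hyperparameters
    dataset = my_training_dataset                  # any Dataset of image lines with GT text

    trainer = Trainer(checkpoint_params,
                      dataset,
                      codec_whitelist=["0", "1"])  # characters to keep if the codec changes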
Example #2
    def __init__(self, params: TrainerParams, scenario, restore=False):
        """Train a DNN using given preprocessing, weights, and data

        The purpose of the Trainer is to handle a default training mechanism.
        As required input it expects a `scenario` and hyperparameters (`params`).

        The steps are
            1. Loading and preprocessing of the dataset
            2. Computation of the codec
            3. Construction of the DNN in the desired Deep Learning Framework
            4. Launch of the training

        During training the Trainer will perform validation checks if validation data is given
        to determine the best model.
        Furthermore, the current status is printed and checkpoints are written.
        """
        super(Trainer, self).__init__(params, scenario, restore)
        self._params: TrainerParams = params
        if (not isinstance(self._params.checkpoint_save_freq, str)
                and self._params.checkpoint_save_freq < 0):
            self._params.checkpoint_save_freq = self._params.early_stopping_params.frequency
        self._params.warmstart.model = (
            checkpoint_path(self._params.warmstart.model)
            if self._params.warmstart.model else None
        )
        self.checkpoint = None
        if self._params.warmstart.model:
            # Manually handle loading
            self.checkpoint = SavedCalamariModel(
                self._params.warmstart.model,
                auto_update=self._params.auto_upgrade_checkpoints,
            )
            self._params.warmstart.model = self.checkpoint.ckpt_path + ".h5"
            self._params.warmstart.trim_graph_name = False

        self._codec_changes = None
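
A hedged sketch of how this params/scenario variant might be driven; everything except `Trainer`, `TrainerParams`, and the `warmstart`/`auto_upgrade_checkpoints` fields seen above is an assumption:

    # Hedged sketch: scenario construction is framework-specific and not shown above.
    params = TrainerParams()
    params.warmstart.model = "path/to/model.ckpt"  # optional warm start; may stay None
    params.auto_upgrade_checkpoints = True         # upgrade older checkpoint formats on load

    trainer = Trainer(params, scenario=MyScenario, restore=False)  # MyScenario is hypothetical
    trainer.train()                                # assumed entry point of the base trainer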
Example #3
    def __init__(
        self,
        checkpoint_params,
        dataset,
        validation_dataset=None,
        txt_preproc=None,
        txt_postproc=None,
        data_preproc=None,
        data_augmenter: DataAugmenter = None,
        n_augmentations=0,
        weights=None,
        codec=None,
        codec_whitelist=None,
        auto_update_checkpoints=True,
        preload_training=False,
        preload_validation=False,
    ):
        """Train a DNN using given preprocessing, weights, and data

        The purpose of the Trainer is to handle a default training mechanism.
        As required input it expects a `dataset` and hyperparameters (`checkpoint_params`).

        The steps are
            1. Loading and preprocessing of the dataset
            2. Computation of the codec
            3. Construction of the DNN in the desired Deep Learning Framework
            4. Launch of the training

        During training the Trainer will perform validation checks if a `validation_dataset` is given
        to determine the best model.
        Furthermore, the current status is printed and checkpoints are written.

        Parameters
        ----------
        checkpoint_params : CheckpointParams
            Proto parameter object that defines all hyperparameters of the model
        dataset : Dataset
            The Dataset used for training
        validation_dataset : Dataset, optional
            The Dataset used for validation, i.e. choosing the best model
        txt_preproc : TextProcessor, optional
            Text preprocessor that is applied on loaded text, before the Codec is computed
        txt_postproc : TextProcessor, optional
            Text processor that is applied to the loaded GT text and to the prediction to obtain the final result
        data_preproc : DataProcessor, optional
            Preprocessing for the image lines (e.g. padding, inversion, deskewing, ...)
        data_augmenter : DataAugmenter, optional
            A DataAugmenter object to use for data augmentation. Count is set by `n_augmentations`
        n_augmentations : int, optional
            The number of augmentations performed by the `data_augmenter`
        weights : str, optional
            Path to a trained model for loading its weights
        codec : Codec, optional
            If provided, the Codec will not be computed automatically based on the GT, but instead `codec` will be used
        codec_whitelist : :obj:`list` of :obj:`str`, optional
            List of characters to be kept when the loaded `weights` have a different codec than the new one.
        auto_update_checkpoints : bool, optional
            Whether to automatically upgrade loaded checkpoints of older versions
        preload_training : bool, optional
            Whether to preload the training data instead of loading it on the fly
        preload_validation : bool, optional
            Whether to preload the validation data instead of loading it on the fly
        """
        self.checkpoint_params = checkpoint_params
        self.txt_preproc = txt_preproc if txt_preproc else text_processor_from_proto(
            checkpoint_params.model.text_preprocessor, "pre")
        self.txt_postproc = txt_postproc if txt_postproc else text_processor_from_proto(
            checkpoint_params.model.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
            checkpoint_params.model.data_preprocessor)
        self.weights = checkpoint_path(weights) if weights else None
        self.codec = codec
        self.codec_whitelist = [] if codec_whitelist is None else codec_whitelist
        self.auto_update_checkpoints = auto_update_checkpoints
        self.dataset = InputDataset(dataset, self.data_preproc,
                                    self.txt_preproc, data_augmenter,
                                    n_augmentations)
        self.validation_dataset = (
            InputDataset(validation_dataset, self.data_preproc, self.txt_preproc)
            if validation_dataset else None
        )
        self.preload_training = preload_training
        self.preload_validation = preload_validation

        if len(self.dataset) == 0:
            raise Exception("Dataset is empty.")

        if self.validation_dataset and len(self.validation_dataset) == 0:
            raise Exception(
                "Validation dataset is empty. Provide valid validation data for early stopping."
            )
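
Given the parameter list above, an instantiation with augmentation and preloading might look as follows; the dataset objects and the augmenter class are illustrative assumptions:

    # Hedged sketch: train_set, val_set, and SimpleDataAugmenter are assumptions.
    trainer = Trainer(checkpoint_params,
                      train_set,
                      validation_dataset=val_set,            # enables best-model selection
                      data_augmenter=SimpleDataAugmenter(),  # some DataAugmenter implementation
                      n_augmentations=5,                     # five augmented copies per line
                      preload_training=True,                 # load training data into memory
                      preload_validation=True)
    trainer.train()                                          # assumed training entry point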
Example #4
    def __init__(self, checkpoint_params,
                 dataset,
                 validation_dataset=None,
                 txt_preproc=None,
                 txt_postproc=None,
                 data_preproc=None,
                 data_augmenter: DataAugmenter = None,
                 n_augmentations=0,
                 weights=None,
                 codec=None,
                 codec_whitelist=None,
                 auto_update_checkpoints=True,
                 preload_training=False,
                 preload_validation=False,
                 ):
        """Train a DNN using given preprocessing, weights, and data

        The purpose of the Trainer is to handle a default training mechanism.
        As required input it expects a `dataset` and hyperparameters (`checkpoint_params`).

        The steps are
            1. Loading and preprocessing of the dataset
            2. Computation of the codec
            3. Construction of the DNN in the desired Deep Learning Framework
            4. Launch of the training

        During training the Trainer will perform validation checks if a `validation_dataset` is given
        to determine the best model.
        Furthermore, the current status is printed and checkpoints are written.

        Parameters
        ----------
        checkpoint_params : CheckpointParams
            Proto parameter object that defines all hyperparameters of the model
        dataset : Dataset
            The Dataset used for training
        validation_dataset : Dataset, optional
            The Dataset used for validation, i.e. choosing the best model
        txt_preproc : TextProcessor, optional
            Text preprocessor that is applied on loaded text, before the Codec is computed
        txt_postproc : TextProcessor, optional
            Text processor that is applied to the loaded GT text and to the prediction to obtain the final result
        data_preproc : DataProcessor, optional
            Preprocessing for the image lines (e.g. padding, inversion, deskewing, ...)
        data_augmenter : DataAugmenter, optional
            A DataAugmenter object to use for data augmentation. Count is set by `n_augmentations`
        n_augmentations : int, optional
            The number of augmentations performed by the `data_augmenter`
        weights : str, optional
            Path to a trained model for loading its weights
        codec : Codec, optional
            If provided, the Codec will not be computed automatically based on the GT, but instead `codec` will be used
        codec_whitelist : :obj:`list` of :obj:`str`, optional
            List of characters to be kept when the loaded `weights` have a different codec than the new one.
        auto_update_checkpoints : bool, optional
            Whether to automatically upgrade loaded checkpoints of older versions
        preload_training : bool, optional
            Whether to preload the training data instead of loading it on the fly
        preload_validation : bool, optional
            Whether to preload the validation data instead of loading it on the fly
        """
        self.checkpoint_params = checkpoint_params
        self.txt_preproc = txt_preproc if txt_preproc else text_processor_from_proto(checkpoint_params.model.text_preprocessor, "pre")
        self.txt_postproc = txt_postproc if txt_postproc else text_processor_from_proto(checkpoint_params.model.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(checkpoint_params.model.data_preprocessor)
        self.weights = checkpoint_path(weights) if weights else None
        self.codec = codec
        # None default avoids a shared mutable default argument.
        self.codec_whitelist = codec_whitelist if codec_whitelist is not None else []
        self.auto_update_checkpoints = auto_update_checkpoints
        self.dataset = InputDataset(dataset, self.data_preproc, self.txt_preproc, data_augmenter, n_augmentations)
        self.validation_dataset = InputDataset(validation_dataset, self.data_preproc, self.txt_preproc) if validation_dataset else None
        self.preload_training = preload_training
        self.preload_validation = preload_validation

        if len(self.dataset) == 0:
            raise Exception("Dataset is empty.")

        if self.validation_dataset and len(self.validation_dataset) == 0:
            raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")