Example 1
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Otherwise the codec is built from the given whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() + self.checkpoint_params.total_time

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes,
                                 progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(
                processes=checkpoint_params.processes,
                progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        else:
            if len(self.codec_whitelist) == 0 or auto_compute_codec:
                codec = Codec.from_input_dataset(
                    [self.dataset, self.validation_dataset],
                    whitelist=self.codec_whitelist,
                    progress_bar=progress_bar)
            else:
                codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json',
                              auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # check that the restored model matches the requested line height
            if restore_model_params.line_height != network_params.features:
                raise Exception(
                    "The model to restore has a line height of {} but a line height of {} is requested"
                    .format(restore_model_params.line_height,
                            checkpoint_params.model.line_height))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(
                len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            # codec_changes holds (deletions, insertions) lists, so test emptiness
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(
            network_params,
            weights=self.weights,
        )
        train_net = backend.create_net(self.dataset,
                                       codec,
                                       restore=None,
                                       weights=self.weights,
                                       graph_type="train",
                                       batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset,
                                      codec,
                                      restore=None,
                                      weights=self.weights,
                                      graph_type="test",
                                      batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()  # to prevent blocking of tensorflow on shutdown
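
All five examples reconcile a restored model's charset with the newly computed codec via Codec.align, which returns a (deletions, insertions) tuple: label indices to drop from the old output layer and new characters to append, with the adapted old codec becoming the final one. A minimal sketch of that idea, using a hypothetical helper rather than Calamari's actual implementation:

    def align_charsets(old_charset, new_charset):
        """Return (deletions, insertions) turning old_charset into new_charset.

        deletions: label indices of old characters missing from the new charset
        insertions: new characters missing from the old charset
        """
        old_set, new_set = set(old_charset), set(new_charset)
        deletions = [i for i, c in enumerate(old_charset) if c not in new_set]
        insertions = [c for c in new_charset if c not in old_set]
        return deletions, insertions

    # No changes at all: both lists are empty, so `not any(...)` is True and
    # the old weight/bias matrix can be reused unchanged.
    assert not any(align_charsets("abc", "abc"))
    assert align_charsets("abcx", "abcd") == ([3], ["d"])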
Example 2
    def train(self, callbacks=None, **kwargs):
        callbacks = callbacks if callbacks else []
        self.setup_data()

        # shortcuts to the data and model parameters
        data: Data = self._data
        model: ModelParams = self.scenario.params.model

        use_training_as_validation = model.ensemble > 0 or self.params.gen.__class__ == CalamariTrainOnlyPipelineParams

        # Setup train pipeline
        train_pipeline = self.params.gen.train_data(data)
        if len(train_pipeline.create_data_generator()) == 0:
            raise ValueError("Training dataset is empty.")

        # Setup validation pipeline
        val_pipeline = None
        if self.params.gen.val_gen():
            if model.ensemble > 0:
                logger.warning(
                    "A validation dataset can not be used when training and ensemble. "
                    "Only a training set is required. Ignoring validation data!"
                )
            else:
                val_pipeline = self.params.gen.val_data(data)
                if len(val_pipeline.create_data_generator()) == 0:
                    raise ValueError(
                        "Validation dataset is empty. Provide valid validation data for early stopping. "
                        "Alternative select train only data generator mode.")

        if self.params.gen.train_data(data).generator_params.preload:
            # preload before the codec is created (not all processors can be applied yet)
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = self.params.gen.train_data(data)
            if val_pipeline:
                val_pipeline = self.params.gen.val_data(data)

        # compute the codec
        codec = data.params.codec
        if not codec:
            if (self._params.codec.auto_compute
                    or len(self._params.codec.resolved_include_chars) == 0):
                codec = Codec.from_input_dataset(
                    filter(lambda x: x, [train_pipeline, val_pipeline]),
                    codec_construction_params=self._params.codec,
                    progress_bar=self._params.progress_bar,
                )
            else:
                codec = Codec(list(self._params.codec.resolved_include_chars))

        data.params.codec = codec
        model.classes = codec.size()

        if self.checkpoint:
            # if we load the weights, take care of codec changes as-well
            restore_checkpoint_params = self.checkpoint.dict
            restore_data_params = restore_checkpoint_params["scenario"]["data"]

            # checks
            if data.params.line_height != restore_data_params["line_height"]:
                raise ValueError(
                    f"The model to restore has a line height of {restore_data_params.line_height}"
                    f" but a line height of {data.params.line_height} is requested"
                )

            # create codec of the same type
            restore_codec = codec.__class__(
                restore_data_params["codec"]["charset"])

            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(
                codec, shrink=not self._params.codec.keep_loaded)
            codec = restore_codec
            logger.info(
                f"Codec changes: {len(codec_changes[0])} deletions, {len(codec_changes[1])} appends"
            )
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes

            self._codec_changes = codec_changes

        model.classes = codec.size()
        data.params.codec = codec
        logger.info(f"CODEC: {codec.charset}")

        if self.params.gen.train_data(data).generator_params.preload:
            # preload after the codec has been created
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = self.params.gen.train_data(data)

        if use_training_as_validation:
            logger.info("Using training data for validation.")
            assert val_pipeline is None
            if self._params.gen.train.preload:
                data._pipelines[PipelineMode.EVALUATION] = RawDataPipeline(
                    [
                        s for s in train_pipeline.samples
                        if not s.meta["augmented"]
                    ],
                    pipeline_params=self._params.gen.setup.train,
                    data_base=data,
                    generator_params=train_pipeline.generator_params,
                    input_processors=train_pipeline._input_processors,
                    output_processors=train_pipeline._output_processors,
                ).to_mode(PipelineMode.EVALUATION)
            else:
                data._pipelines[PipelineMode.EVALUATION] = train_pipeline.to_mode(
                    PipelineMode.EVALUATION)
        else:
            if val_pipeline is None:
                raise ValueError(
                    "No validation data provided."
                    "Set 'trainer.gen TrainOnly' to pass only training data."
                    "Validation will be performed on the training data in this case."
                    "Alternatively, set 'trainer.gen SplitTrain' and to use by "
                    "default 20% of the training data for validation")

        last_logs = None
        if self._params.current_stage == 0:
            last_logs = super(Trainer, self).train(callbacks=callbacks)

        data_aug = self._data.params.pre_proc.processors_of_type(
            AugmentationProcessorParams)
        if (self._params.data_aug_retrain_on_original and len(data_aug) > 0
                and any(p.n_augmentations != 0 for p in data_aug)):
            logger.info("Starting training on original data only")
            if self._params.current_stage == 0:
                self._params.current_epoch = 0
                self._params.current_stage = 1
                self._params.early_stopping.current = 1  # CER = 100% as initial value
                self._params.early_stopping.n = 0

            # Remove data augmenter
            self._data.params.pre_proc.erase_all(AugmentationProcessorParams)
            # Remove augmented samples if preloaded
            if isinstance(train_pipeline, RawDataPipeline):
                train_pipeline.samples = [
                    s for s in train_pipeline.samples
                    if not s.meta.get("augmented", False)
                ]

            logger.info(
                f"Training on {len(train_pipeline.create_data_generator())} samples."
            )

            super(Trainer, self).setup_steps_per_epoch()

            # replace callbacks that require steps per epoch as parameter
            first = True
            for i, cb in enumerate(self._callbacks[:]):
                if isinstance(cb, TensorBoardCallback):
                    cb.steps_per_epoch = self._steps_per_epoch

                if isinstance(cb, TrainerCheckpointsCallback):
                    if first:
                        self._callbacks[i] = self.create_train_params_logger_callback(
                            store_params=False, store_weights=True)
                        first = False
                    else:
                        self._callbacks[i] = self.create_train_params_logger_callback(
                            store_params=True, store_weights=False)
            logger_callback = next(c for c in self._callbacks
                                   if isinstance(c, LoggerCallback))
            super(Trainer, self).fit()
            last_logs = logger_callback.last_logs

        logger.info("Training finished")
        return last_logs
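
Example 2's second stage retrains on the original lines only: it resets the stage counters, erases the augmentation processors from the pre-processing pipeline, and, when samples were preloaded, drops the augmented ones by their meta flag. A minimal sketch of that filter, assuming each sample carries an "augmented" entry in its meta dict as in the code above:

    samples = [
        {"line": "original", "meta": {"augmented": False}},
        {"line": "distorted copy", "meta": {"augmented": True}},
        {"line": "no flag at all", "meta": {}},
    ]
    # Samples without the flag are treated as originals (meta.get default).
    originals = [s for s in samples if not s["meta"].get("augmented", False)]
    assert [s["line"] for s in originals] == ["original", "no flag at all"]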
Example 3
    def train(self, auto_compute_codec=False, progress_bar=False, training_callback=ConsoleTrainingCallback()):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Otherwise the codec is built from the given whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        training_callback : TrainingCallback
            Callback for the training process (e.g., for displaying the current CER and loss in the console)

        """
        with ExitStackWithPop() as exit_stack:
            checkpoint_params = self.checkpoint_params

            train_start_time = time.time() + self.checkpoint_params.total_time

            exit_stack.enter_context(self.dataset)
            if self.validation_dataset:
                exit_stack.enter_context(self.validation_dataset)

            # load training dataset
            if self.preload_training:
                new_dataset = self.dataset.to_raw_input_dataset(processes=checkpoint_params.processes, progress_bar=progress_bar)
                exit_stack.pop(self.dataset)
                self.dataset = new_dataset
                exit_stack.enter_context(self.dataset)

            # load validation dataset
            if self.validation_dataset and self.preload_validation:
                new_dataset = self.validation_dataset.to_raw_input_dataset(processes=checkpoint_params.processes, progress_bar=progress_bar)
                exit_stack.pop(self.validation_dataset)
                self.validation_dataset = new_dataset
                exit_stack.enter_context(self.validation_dataset)

            # compute the codec
            if self.codec:
                codec = self.codec
            else:
                if len(self.codec_whitelist) == 0 or auto_compute_codec:
                    codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                                     whitelist=self.codec_whitelist, progress_bar=progress_bar)
                else:
                    codec = Codec.from_texts([], whitelist=self.codec_whitelist)

            # create backend
            network_params = checkpoint_params.model.network
            network_params.features = checkpoint_params.model.line_height
            if self.weights:
                # if we load the weights, take care of codec changes as well
                ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
                restore_checkpoint_params = ckpt.checkpoint
                restore_model_params = restore_checkpoint_params.model

                # check that the restored model matches the requested line height
                if restore_model_params.line_height != network_params.features:
                    raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                        restore_model_params.line_height, checkpoint_params.model.line_height
                    ))

                # create codec of the same type
                restore_codec = codec.__class__(restore_model_params.codec.charset)

                # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
                codec_changes = restore_codec.align(codec, shrink=not self.keep_loaded_codec)
                codec = restore_codec
                print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
                # The actual weight/bias matrix will be changed after loading the old weights
                # codec_changes holds (deletions, insertions) lists, so test emptiness
                if not any(codec_changes):
                    codec_changes = None  # No codec changes
            else:
                codec_changes = None

            # store the new codec
            network_params.classes = len(codec)
            checkpoint_params.model.codec.charset[:] = codec.charset
            print("CODEC: {}".format(codec.charset))

            backend = create_backend_from_checkpoint(
                checkpoint_params=checkpoint_params,
                processes=checkpoint_params.processes,
            )
            train_net = backend.create_net(codec, graph_type="train",
                                           checkpoint_to_load=Checkpoint(self.weights) if self.weights else None,
                                           batch_size=checkpoint_params.batch_size, codec_changes=codec_changes)

            if checkpoint_params.current_stage == 0:
                self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset, training_callback)

            if checkpoint_params.data_aug_retrain_on_original and self.data_augmenter and self.n_augmentations != 0:
                print("Starting training on original data only")
                if checkpoint_params.current_stage == 0:
                    checkpoint_params.current_stage = 1
                    checkpoint_params.iter = 0
                    checkpoint_params.early_stopping_best_at_iter = 0
                    checkpoint_params.early_stopping_best_cur_nbest = 0
                    checkpoint_params.early_stopping_best_accuracy = 0

                self.dataset.generate_only_non_augmented = True  # this is the important line!
                self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset, training_callback)
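
Example 3 swaps the lazily loading dataset for its preloaded replacement while both are registered on an exit stack, which requires removing a single context manager from the middle of the stack. ExitStackWithPop is not shown in the example; the sketch below is a hypothetical reconstruction on top of contextlib.ExitStack, and it relies on a CPython 3.7+ internal detail (each entry of _exit_callbacks is an (is_sync, bound_exit_method) pair):

    from collections import deque
    from contextlib import ExitStack

    class ExitStackWithPop(ExitStack):
        """ExitStack that can exit one registered context manager early."""

        def pop(self, cm):
            # Keep every callback except the one bound to `cm`, which is
            # exited immediately and removed from the stack.
            callbacks, self._exit_callbacks = self._exit_callbacks, deque()
            found = None
            while callbacks:
                entry = callbacks.popleft()
                if getattr(entry[1], "__self__", None) is cm:
                    found = entry[1]
                else:
                    self._exit_callbacks.append(entry)
            if found is None:
                raise KeyError("context manager not found on the stack")
            found(None, None, None)  # call cm.__exit__ with no exception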
Example 4
    def train(self, callbacks=None, **kwargs):
        callbacks = callbacks if callbacks else []

        # shortcuts to the data and model parameters
        data: Data = self._data
        model_params: ModelParams = self.scenario.params.model_params

        train_pipeline = data.get_train_data()
        if len(train_pipeline.create_data_generator()) == 0:
            raise ValueError("Training dataset is empty.")

        if data.params().val:
            val_pipeline = data.get_val_data()
            if len(val_pipeline.create_data_generator()) == 0:
                raise ValueError(
                    "Validation dataset is empty. Provide valid validation data for early stopping."
                )
        else:
            val_pipeline = None

        if self._params.preload_training:
            # preload before the codec is created (not all processors can be applied yet)
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = data.get_train_data()
            if val_pipeline:
                val_pipeline = data.get_val_data()

        # compute the codec
        codec = data.params().codec
        if not codec:
            if (len(self._params.codec_whitelist) == 0
                    or self._params.auto_compute_codec):
                codec = Codec.from_input_dataset(
                    filter(lambda x: x, [train_pipeline, val_pipeline]),
                    whitelist=self._params.codec_whitelist,
                    progress_bar=self._params.progress_bar)
            else:
                codec = Codec.from_texts(
                    [], whitelist=self._params.codec_whitelist)

        data.params().codec = codec
        data.params().downscale_factor_ = model_params.compute_downscale_factor()
        model_params.classes = codec.size()

        if self.checkpoint:
            # if we load the weights, take care of codec changes as well
            restore_checkpoint_params = self.checkpoint.dict
            restore_data_params = restore_checkpoint_params['scenario_params'][
                'data_params']

            # checks
            if data.params().line_height_ != restore_data_params['line_height_']:
                raise ValueError(
                    f"The model to restore has a line height of {restore_data_params['line_height_']}"
                    f" but a line height of {data.params().line_height_} is requested"
                )

            # create codec of the same type
            restore_codec = codec.__class__(
                restore_data_params['codec']['charset'])

            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(
                codec, shrink=not self._params.keep_loaded_codec)
            codec = restore_codec
            logger.info(
                f"Codec changes: {len(codec_changes[0])} deletions, {len(codec_changes[1])} appends"
            )
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        model_params.classes = codec.size()
        data.params().codec = codec
        logger.info(f"CODEC: {codec.charset}")

        if self._params.preload_training:
            # preload after the codec has been created
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = data.get_train_data()

        if self._params.current_stage == 0:
            super(Trainer, self).train(
                callbacks=callbacks,
                warmstart_fn=partial(WarmstarterWithCodecAdaption,
                                     codec_changes=codec_changes),
            )

        if (self._params.data_aug_retrain_on_original
                and self._params.scenario_params.data_params.data_aug_params.to_abs() > 0):
            logger.info("Starting training on original data only")
            if self._params.current_stage == 0:
                self._params.current_epoch = 0
                self._params.current_stage = 1
                self._params.early_stopping_params.current_ = 1  # CER = 100% as initial value
                self._params.early_stopping_params.n_ = 0

            # Remove data augmenter
            self._data.params().pre_processors_.sample_processors = [
                p
                for p in self._data.params().pre_processors_.sample_processors
                if p.name != AugmentationProcessor.__name__
            ]
            # Remove augmented samples if preloaded
            if isinstance(train_pipeline, RawDataPipeline):
                train_pipeline.samples = [
                    s for s in train_pipeline.samples
                    if not s.meta.get('augmented', False)
                ]

            logger.info(
                f"Training on {len(train_pipeline.create_data_generator())} samples."
            )

            super(Trainer, self).setup_steps_per_epoch()

            # replace callbacks that require steps per epoch as parameter
            tb_callback = next(cb for cb in self._callbacks
                               if isinstance(cb, TensorBoardCallback))
            tb_callback.steps_per_epoch = self._steps_per_epoch
            i = next(i for i, cb in enumerate(self._callbacks)
                     if isinstance(cb, TrainParamsLoggerCallback))
            del self._callbacks[i]
            self._callbacks.insert(i,
                                   self.create_train_params_logger_callback())

            super(Trainer, self).fit()

        logger.info("Training finished")
Example 5
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Otherwise the codec is built from the given whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() + self.checkpoint_params.total_time

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        else:
            if len(self.codec_whitelist) == 0 or auto_compute_codec:
                codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                                 whitelist=self.codec_whitelist, progress_bar=progress_bar)
            else:
                codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # check that the restored model matches the requested line height
            if restore_model_params.line_height != network_params.features:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            # codec_changes holds (deletions, insertions) lists, so test emptiness
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        train_net = backend.create_net(self.dataset, codec, restore=None, weights=self.weights, graph_type="train", batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset, codec, restore=None, weights=self.weights, graph_type="test", batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()   # to prevent blocking of tensorflow on shutdown
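
A hedged usage sketch for the Trainer API of Examples 1, 3, and 5; the constructor arguments are assumptions for illustration, while the train() keywords match the signatures above:

    # Hypothetical setup: attribute names follow the examples above.
    trainer = Trainer(checkpoint_params,
                      dataset,
                      validation_dataset=validation_dataset,
                      codec_whitelist=[])
    # An empty whitelist (or auto_compute_codec=True) makes train() build the
    # codec from the ground truth via Codec.from_input_dataset.
    trainer.train(auto_compute_codec=True, progress_bar=True)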