Example no. 1
    def predict_raw(self, datas, batch_size=1, processes=1, progress_bar=True, apply_preproc=True):
        # preprocessing step
        if apply_preproc:
            datas = self.data_preproc.apply(datas, processes=processes, progress_bar=progress_bar)

        codec = self.codec if self.codec else Codec(self.model_params.codec.charset)

        # create backend
        self.backend.set_prediction_data(datas)
        self.backend.prepare(train=False)

        if progress_bar:
            out = tqdm(self.backend.prediction_step(batch_size), desc="Prediction", total=len(datas))
        else:
            out = self.backend.prediction_step(batch_size)

        for p in out:
            yield PredictionResult(p, codec=codec, text_postproc=self.text_postproc)
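
A minimal usage sketch for this generator variant (the `predictor` and `images` names are illustrative assumptions, not identifiers from the source): results can be consumed lazily, one line at a time.

    # iterate over PredictionResult objects as they are produced
    for result in predictor.predict_raw(images, batch_size=16, progress_bar=False):
        print(result.sentence)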
Example no. 2
    def predict_raw(self,
                    datas,
                    batch_size=1,
                    processes=1,
                    progress_bar=True,
                    apply_preproc=True):
        # preprocessing step
        if apply_preproc:
            datas = self.data_preproc.apply(datas,
                                            processes=processes,
                                            progress_bar=progress_bar)

        codec = self.codec if self.codec else Codec(
            self.model_params.codec.charset)

        # create backend
        self.backend.set_prediction_data(datas)
        self.backend.prepare(train=False)

        prediction_start_time = time.time()

        if progress_bar:
            out = list(
                tqdm(self.backend.prediction_step(batch_size),
                     desc="Prediction",
                     total=self.backend.num_prediction_steps(batch_size)))
        else:
            out = list(self.backend.prediction_step(batch_size))

        prediction_results = [
            PredictionResult(
                p,
                codec=codec,
                text_postproc=self.text_postproc,
            ) for p in out
        ]

        return prediction_results, time.time() - prediction_start_time
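
In contrast to the generator variant of Example no. 1, this version materializes all results and additionally returns the elapsed prediction time. A usage sketch under the same illustrative assumptions (`predictor`, `images`):

    results, prediction_time = predictor.predict_raw(images, batch_size=16)
    print("{} lines predicted in {:.2f}s".format(len(results), prediction_time))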
Example no. 3
    def __init__(self, checkpoint=None, text_postproc=None, data_preproc=None, codec=None, network=None,
                 batch_size=1, processes=1,
                 auto_update_checkpoints=True,
                 with_gt=False,
                 ):
        """ Predicting a dataset based on a trained model

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load; alternatively, you can directly use a loaded `network`
        text_postproc : TextProcessor, optional
            text processor to be applied to the predicted sentence for the final output.
            If loaded from a checkpoint, the text processor will be loaded from it.
        data_preproc : DataProcessor, optional
            data processor (must be the same as that of the trained model) to be applied to the input image.
            If loaded from a checkpoint, the data processor will be loaded from it.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to use. Alternatively, you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction
        processes : int, optional
            The number of processes to use for prediction
        auto_update_checkpoints : bool, optional
            Update old models automatically (this will change the checkpoint files)
        with_gt : bool, optional
            The prediction will also output the ground truth if available, else None
        """
        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes
        self.auto_update_checkpoints = auto_update_checkpoints
        self.with_gt = with_gt

        if checkpoint:
            if network:
                raise Exception("Either a checkpoint or a network can be provided")

            ckpt = Checkpoint(checkpoint, auto_update=self.auto_update_checkpoints)
            self.checkpoint = ckpt.ckpt_path
            checkpoint_params = ckpt.checkpoint
            self.model_params = checkpoint_params.model
            self.codec = codec if codec else Codec(self.model_params.codec.charset)

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(self.network_params, restore=self.checkpoint, processes=processes)
            self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(self.model_params.data_preprocessor)
            self.network = backend.create_net(
                dataset=None,
                codec=self.codec,
                restore=self.checkpoint, weights=None, graph_type="predict", batch_size=batch_size)
        elif network:
            self.codec = codec
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            if not codec:
                raise Exception("A codec is required if preloaded network is used.")
        else:
            raise Exception("Either a checkpoint or a existing backend must be provided")

        self.out_to_in_trans = OutputToInputTransformer(self.data_preproc, self.network)
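
A hedged construction sketch for this `Predictor`: either a checkpoint path or an already built network is passed, never both, and with a preloaded network the codec (and ideally the matching processors) must be supplied explicitly. The path and the `my_*` objects below are hypothetical placeholders:

    # load codec, processors and network from a checkpoint
    predictor = Predictor(checkpoint="models/model_00000042", batch_size=16, processes=4)

    # or wrap an existing network; a codec is then mandatory
    predictor = Predictor(network=my_net, codec=my_codec,
                          data_preproc=my_preproc, text_postproc=my_postproc)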
Example no. 4
    def train(self, progress_bar=False):
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() + self.checkpoint_params.total_time

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
        else:
            validation_datas, validation_txts = [], []


        # preprocessing steps
        texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                             processes=checkpoint_params.processes, progress_bar=progress_bar)

            # TODO: validation data augmentation
            # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
            #                                                  processes=checkpoint_params.processes, progress_bar=progress_bar)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            with open(self.weights + '.json', 'r') as f:
                restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
                restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != network_params.features:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes are a tuple (deletions/insertions); the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        backend.set_train_data(datas, labels)
        backend.set_prediction_data(validation_datas)
        if codec_changes:
            backend.realign_model_labels(*codec_changes)
        backend.prepare(train=True)

        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             backend=backend)

        # Start the actual training
        # ====================================================================================

        iter = checkpoint_params.iter

        # helper function to write a checkpoint
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            backend.save_checkpoint(checkpoint_path)
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = backend.train_step(checkpoint_params.batch_size)

                if not np.isfinite(result['loss']):
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                        continue

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if iter % checkpoint_params.display == 0:
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    print(" PRED: '{}'".format(pred_sentence))
                    print(" TRUE: '{}'".format(gt_sentence))

                if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                               progress_bar=progress_bar, apply_preproc=False)
                    pred_texts = [d.sentence for d in out]
                    result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

        except KeyboardInterrupt as e:
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise e

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
Example no. 5

if __name__ == "__main__":
    from calamari_ocr.ocr import Codec

    this_dir = os.path.dirname(os.path.realpath(__file__))
    base_path = os.path.abspath(os.path.join(this_dir, "..", "..", "test", "data", "uw3_50lines", "train"))

    fdr = FileDataParams(
        num_processes=8,
        images=[os.path.join(base_path, "*.png")],
        limit=1000,
    )

    params = DataParams(
        codec=Codec("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,:;-?+=_()*{}[]`@#$%^&'\""),
        downscale_factor=4,
        line_height=48,
        pre_proc=SequentialProcessorPipelineParams(
            run_parallel=True,
            processors=default_image_processors()
            + default_text_pre_processors()
            + [
                AugmentationProcessorParams(
                    modes={PipelineMode.TRAINING},
                    data_aug_params=DataAugmentationAmount(amount=2),
                ),
                PrepareSampleProcessorParams(modes=INPUT_PROCESSOR),
            ],
        ),
        post_proc=SequentialProcessorPipelineParams(run_parallel=False),
Example no. 6
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Else provide a codec using a whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() + self.checkpoint_params.total_time

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes,
                                 progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(
                processes=checkpoint_params.processes,
                progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        else:
            if len(self.codec_whitelist) == 0 or auto_compute_codec:
                codec = Codec.from_input_dataset(
                    [self.dataset, self.validation_dataset],
                    whitelist=self.codec_whitelist,
                    progress_bar=progress_bar)
            else:
                codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json',
                              auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != network_params.features:
                raise Exception(
                    "The model to restore has a line height of {} but a line height of {} is requested"
                    .format(restore_model_params.line_height,
                            checkpoint_params.model.line_height))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes are a tuple (deletions/insertions); the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(
                len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(
            network_params,
            weights=self.weights,
        )
        train_net = backend.create_net(self.dataset,
                                       codec,
                                       restore=None,
                                       weights=self.weights,
                                       graph_type="train",
                                       batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset,
                                      codec,
                                      restore=None,
                                      weights=self.weights,
                                      graph_type="test",
                                      batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()  # to prevent blocking of tensorflow on shutdown
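
The `align` call above reports how the charset of the restored model differs from the newly computed one. A plain-Python illustration of the idea (not the actual `Codec` implementation):

    old_charset = list("abc")   # charset of the restored model
    new_charset = list("abd")   # charset computed from the new training data
    deletions = [i for i, c in enumerate(old_charset) if c not in new_charset]  # label positions to drop
    insertions = [c for c in new_charset if c not in old_charset]               # characters to append
    # the restored codec is kept (plus the appended characters), so the loaded
    # weight/bias rows of the output layer can be re-aligned instead of retrained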
Example no. 7
    def train(self, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() + self.checkpoint_params.total_time

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(
            skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception(
                "Empty dataset is not allowed. Check if the data is at the correct location"
            )

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1,
                                                 progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(
                skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception(
                    "Validation dataset is empty. Provide valid validation data for early stopping."
                )
        else:
            validation_datas, validation_txts = [], []

        # preprocessing steps
        texts = self.txt_preproc.apply(txts,
                                       processes=checkpoint_params.processes,
                                       progress_bar=progress_bar)
        datas, params = [
            list(a) for a in zip(
                *self.data_preproc.apply(datas,
                                         processes=checkpoint_params.processes,
                                         progress_bar=progress_bar))
        ]
        validation_txts = self.txt_preproc.apply(
            validation_txts,
            processes=checkpoint_params.processes,
            progress_bar=progress_bar)
        validation_data_params = self.data_preproc.apply(
            validation_datas,
            processes=checkpoint_params.processes,
            progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(
            texts, whitelist=self.codec_whitelist)

        # store original data in case data augmentation is used with a second step
        original_texts = texts
        original_datas = datas

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(
                datas,
                texts,
                n_augmentations=self.n_augmentations,
                processes=checkpoint_params.processes,
                progress_bar=progress_bar)

            # TODO: validation data augmentation
            # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
            #                                                  processes=checkpoint_params.processes, progress_bar=progress_bar)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json',
                              auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != network_params.features:
                raise Exception(
                    "The model to restore has a line height of {} but a line height of {} is requested"
                    .format(restore_model_params.line_height,
                            checkpoint_params.model.line_height))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes are a tuple (deletions/insertions); the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(
                len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(
            network_params,
            weights=self.weights,
        )
        train_net = backend.create_net(restore=None,
                                       weights=self.weights,
                                       graph_type="train",
                                       batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(restore=None,
                                      weights=self.weights,
                                      graph_type="test",
                                      batch_size=checkpoint_params.batch_size)
        train_net.set_data(datas, labels)
        test_net.set_data(validation_datas, validation_txts)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, validation_data_params,
                            train_start_time, progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.data_augmenter and self.n_augmentations > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            train_net.set_data(original_datas,
                               [codec.encode(txt) for txt in original_texts])
            test_net.set_data(validation_datas, validation_txts)
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, validation_data_params,
                            train_start_time, progress_bar)
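
`codec.encode(txt)` maps every ground-truth line to a sequence of integer labels, and `len(codec)` fixes the number of output classes of the network. A conceptual sketch (not the `Codec` implementation):

    charset = ["a", "b", "c", " "]
    char_to_label = {c: i for i, c in enumerate(charset)}
    labels = [char_to_label[c] for c in "ab c"]   # -> [0, 1, 3, 2]
    num_classes = len(charset)                    # corresponds to network_params.classes = len(codec)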
Example no. 8
    def train(self, callbacks=None, **kwargs):
        callbacks = callbacks if callbacks else []
        self.setup_data()

        # load preloaded dataset
        data: Data = self._data
        model: ModelParams = self.scenario.params.model

        use_training_as_validation = model.ensemble > 0 or self.params.gen.__class__ == CalamariTrainOnlyPipelineParams

        # Setup train pipeline
        train_pipeline = self.params.gen.train_data(data)
        if len(train_pipeline.create_data_generator()) == 0:
            raise ValueError("Training dataset is empty.")

        # Setup validation pipeline
        val_pipeline = None
        if self.params.gen.val_gen():
            if model.ensemble > 0:
                logger.warning(
                    "A validation dataset can not be used when training and ensemble. "
                    "Only a training set is required. Ignoring validation data!"
                )
            else:
                val_pipeline = self.params.gen.val_data(data)
                if len(val_pipeline.create_data_generator()) == 0:
                    raise ValueError(
                        "Validation dataset is empty. Provide valid validation data for early stopping. "
                        "Alternative select train only data generator mode.")

        if self.params.gen.train_data(data).generator_params.preload:
            # preload before codec was created (not all processors can be applied, yet)
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = self.params.gen.train_data(data)
            if val_pipeline:
                val_pipeline = self.params.gen.val_data(data)

        # compute the codec
        codec = data.params.codec
        if not codec:
            if self._params.codec.auto_compute or len(
                    self._params.codec.resolved_include_chars) == 0:
                codec = Codec.from_input_dataset(
                    filter(lambda x: x, [train_pipeline, val_pipeline]),
                    codec_construction_params=self._params.codec,
                    progress_bar=self._params.progress_bar,
                )
            else:
                codec = Codec(list(self._params.codec.resolved_include_chars))

        data.params.codec = codec
        model.classes = codec.size()

        if self.checkpoint:
            # if we load the weights, take care of codec changes as well
            restore_checkpoint_params = self.checkpoint.dict
            restore_data_params = restore_checkpoint_params["scenario"]["data"]

            # checks
            if data.params.line_height != restore_data_params["line_height"]:
                raise ValueError(
                    f"The model to restore has a line height of {restore_data_params.line_height}"
                    f" but a line height of {data.params.line_height} is requested"
                )

            # create codec of the same type
            restore_codec = codec.__class__(
                restore_data_params["codec"]["charset"])

            # the codec changes are a tuple (deletions/insertions); the new codec is the changed old one
            codec_changes = restore_codec.align(
                codec, shrink=not self._params.codec.keep_loaded)
            codec = restore_codec
            logger.info(
                f"Codec changes: {len(codec_changes[0])} deletions, {len(codec_changes[1])} appends"
            )
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes

            self._codec_changes = codec_changes

        model.classes = codec.size()
        data.params.codec = codec
        logger.info(f"CODEC: {codec.charset}")

        if self.params.gen.train_data(data).generator_params.preload:
            # preload after codec was created
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = self.params.gen.train_data(data)

        if use_training_as_validation:
            logger.info("Using training data for validation.")
            assert val_pipeline is None
            if self._params.gen.train.preload:
                data._pipelines[PipelineMode.EVALUATION] = RawDataPipeline(
                    [
                        s for s in train_pipeline.samples
                        if not s.meta["augmented"]
                    ],
                    pipeline_params=self._params.gen.setup.train,
                    data_base=data,
                    generator_params=train_pipeline.generator_params,
                    input_processors=train_pipeline._input_processors,
                    output_processors=train_pipeline._output_processors,
                ).to_mode(PipelineMode.EVALUATION)
            else:
                data._pipelines[
                    PipelineMode.EVALUATION] = train_pipeline.to_mode(
                        PipelineMode.EVALUATION)
        else:
            if val_pipeline is None:
                raise ValueError(
                    "No validation data provided."
                    "Set 'trainer.gen TrainOnly' to pass only training data."
                    "Validation will be performed on the training data in this case."
                    "Alternatively, set 'trainer.gen SplitTrain' and to use by "
                    "default 20% of the training data for validation")

        last_logs = None
        if self._params.current_stage == 0:
            last_logs = super(Trainer, self).train(callbacks=callbacks, )

        data_aug = self._data.params.pre_proc.processors_of_type(
            AugmentationProcessorParams)
        if (self._params.data_aug_retrain_on_original and len(data_aug) > 0
                and any(p.n_augmentations != 0 for p in data_aug)):
            logger.info("Starting training on original data only")
            if self._params.current_stage == 0:
                self._params.current_epoch = 0
                self._params.current_stage = 1
                self._params.early_stopping.current = 1  # CER = 100% as initial value
                self._params.early_stopping.n = 0

            # Remove data augmenter
            self._data.params.pre_proc.erase_all(AugmentationProcessorParams)
            # Remove augmented samples if preloaded
            if isinstance(train_pipeline, RawDataPipeline):
                train_pipeline.samples = [
                    s for s in train_pipeline.samples
                    if not s.meta.get("augmented", False)
                ]

            logger.info(
                f"Training on {len(train_pipeline.create_data_generator())} samples."
            )

            super(Trainer, self).setup_steps_per_epoch()

            # replace callbacks that require steps per epoch as parameter
            first = True
            for i, cb in enumerate(self._callbacks[:]):
                if isinstance(cb, TensorBoardCallback):
                    cb.steps_per_epoch = self._steps_per_epoch

                if isinstance(cb, TrainerCheckpointsCallback):
                    if first:
                        self._callbacks[
                            i] = self.create_train_params_logger_callback(
                                store_params=False, store_weights=True)
                        first = False
                    else:
                        self._callbacks[
                            i] = self.create_train_params_logger_callback(
                                store_params=True, store_weights=False)
            logger_callback = next(c for c in self._callbacks
                                   if isinstance(c, LoggerCallback))
            super(Trainer, self).fit()
            last_logs = logger_callback.last_logs

        logger.info("Training finished")
        return last_logs
Example no. 9
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 network=None,
                 batch_size=1,
                 processes=1):
        """ Predicting a dataset based on a trained model

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load; alternatively, you can directly use a loaded `network`
        text_postproc : TextProcessor, optional
            text processor to be applied to the predicted sentence for the final output.
            If loaded from a checkpoint, the text processor will be loaded from it.
        data_preproc : DataProcessor, optional
            data processor (must be the same as that of the trained model) to be applied to the input image.
            If loaded from a checkpoint, the data processor will be loaded from it.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to use. Alternatively, you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction
        processes : int, optional
            The number of processes to use for prediction
        """
        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes

        if checkpoint:
            if network:
                raise Exception(
                    "Either a checkpoint or a network can be provided, not both")

            with open(checkpoint + '.json', 'r') as f:
                checkpoint_params = json_format.Parse(f.read(),
                                                      CheckpointParams())
                self.model_params = checkpoint_params.model

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(self.network_params,
                                                restore=self.checkpoint,
                                                processes=processes)
            self.network = backend.create_net(restore=self.checkpoint,
                                              weights=None,
                                              graph_type="predict",
                                              batch_size=batch_size)
            self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(
                self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
                self.model_params.data_preprocessor)
        elif network:
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            if not codec:
                raise Exception(
                    "A codec is required if preloaded network is used.")
        else:
            raise Exception(
                "Either a checkpoint or an existing network must be provided")

        self.codec = codec if codec else Codec(self.model_params.codec.charset)
        self.out_to_in_trans = OutputToInputTransformer(
            self.data_preproc, self.network)
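
Checkpoint metadata is a protobuf message serialized to JSON; the round-trip uses `google.protobuf.json_format`, as in the snippet above and in the checkpoint writer of the trainer. A minimal sketch, assuming `checkpoint_params` is a `CheckpointParams` message instance:

    from google.protobuf import json_format

    # serialize the CheckpointParams message next to the weights ...
    json_str = json_format.MessageToJson(checkpoint_params)
    # ... and parse it back when restoring
    restored_params = json_format.Parse(json_str, CheckpointParams())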
Example no. 10
    def train(self, callbacks=None, **kwargs):
        callbacks = callbacks if callbacks else []

        # load preloaded dataset
        data: Data = self._data
        model_params: ModelParams = self.scenario.params.model_params

        train_pipeline = data.get_train_data()
        if len(train_pipeline.create_data_generator()) == 0:
            raise ValueError("Training dataset is empty.")

        if data.params().val:
            val_pipeline = data.get_val_data()
            if len(val_pipeline.create_data_generator()) == 0:
                raise ValueError(
                    "Validation dataset is empty. Provide valid validation data for early stopping."
                )
        else:
            val_pipeline = None

        if self._params.preload_training:
            # preload before codec was created (not all processors can be applied, yet)
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = data.get_train_data()
            if val_pipeline:
                val_pipeline = data.get_val_data()

        # compute the codec
        codec = data.params().codec
        if not codec:
            if len(self._params.codec_whitelist
                   ) == 0 or self._params.auto_compute_codec:
                codec = Codec.from_input_dataset(
                    filter(lambda x: x, [train_pipeline, val_pipeline]),
                    whitelist=self._params.codec_whitelist,
                    progress_bar=self._params.progress_bar)
            else:
                codec = Codec.from_texts(
                    [], whitelist=self._params.codec_whitelist)

        data.params().codec = codec
        data.params(
        ).downscale_factor_ = model_params.compute_downscale_factor()
        model_params.classes = codec.size()

        if self.checkpoint:
            # if we load the weights, take care of codec changes as well
            restore_checkpoint_params = self.checkpoint.dict
            restore_data_params = restore_checkpoint_params['scenario_params'][
                'data_params']

            # checks
            if data.params(
            ).line_height_ != restore_data_params['line_height_']:
                raise ValueError(
                    f"The model to restore has a line height of {restore_data_params.line_height_}"
                    f" but a line height of {data.params().line_height_} is requested"
                )

            # create codec of the same type
            restore_codec = codec.__class__(
                restore_data_params['codec']['charset'])

            # the codec changes are a tuple (deletions/insertions); the new codec is the changed old one
            codec_changes = restore_codec.align(
                codec, shrink=not self._params.keep_loaded_codec)
            codec = restore_codec
            logger.info(
                f"Codec changes: {len(codec_changes[0])} deletions, {len(codec_changes[1])} appends"
            )
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        model_params.classes = codec.size()
        data.params().codec = codec
        logger.info(f"CODEC: {codec.charset}")

        if self._params.preload_training:
            # preload after codec was created
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = data.get_train_data()

        if self._params.current_stage == 0:
            super(Trainer, self).train(
                callbacks=callbacks,
                warmstart_fn=partial(WarmstarterWithCodecAdaption,
                                     codec_changes=codec_changes),
            )

        if self._params.data_aug_retrain_on_original and self._params.scenario_params.data_params.data_aug_params.to_abs(
        ) > 0:
            logger.info("Starting training on original data only")
            if self._params.current_stage == 0:
                self._params.current_epoch = 0
                self._params.current_stage = 1
                self._params.early_stopping_params.current_ = 1  # CER = 100% as initial value
                self._params.early_stopping_params.n_ = 0

            # Remove data augmenter
            self._data.params().pre_processors_.sample_processors = [
                p
                for p in self._data.params().pre_processors_.sample_processors
                if p.name != AugmentationProcessor.__name__
            ]
            # Remove augmented samples if preloaded
            if isinstance(train_pipeline, RawDataPipeline):
                train_pipeline.samples = [
                    s for s in train_pipeline.samples
                    if not s.meta.get('augmented', False)
                ]

            logger.info(
                f"Training on {len(train_pipeline.create_data_generator())} samples."
            )

            super(Trainer, self).setup_steps_per_epoch()

            # replace callbacks that require steps per epoch as parameter
            tb_callback = next(cb for cb in self._callbacks
                               if isinstance(cb, TensorBoardCallback))
            tb_callback.steps_per_epoch = self._steps_per_epoch
            i = next(i for i, cb in enumerate(self._callbacks)
                     if isinstance(cb, TrainParamsLoggerCallback))
            del self._callbacks[i]
            self._callbacks.insert(i,
                                   self.create_train_params_logger_callback())

            super(Trainer, self).fit()

        logger.info("Training finished")
Example no. 11
    def train(self, auto_compute_codec=False, progress_bar=False, training_callback=ConsoleTrainingCallback()):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Else provide a codec using a whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        training_callback : TrainingCallback
            Callback for the training process (e.g., for displaying the current cer, loss in the console)

        """
        with ExitStackWithPop() as exit_stack:
            checkpoint_params = self.checkpoint_params

            train_start_time = time.time() + self.checkpoint_params.total_time

            exit_stack.enter_context(self.dataset)
            if self.validation_dataset:
                exit_stack.enter_context(self.validation_dataset)

            # load training dataset
            if self.preload_training:
                new_dataset = self.dataset.to_raw_input_dataset(processes=checkpoint_params.processes, progress_bar=progress_bar)
                exit_stack.pop(self.dataset)
                self.dataset = new_dataset
                exit_stack.enter_context(self.dataset)

            # load validation dataset
            if self.validation_dataset and self.preload_validation:
                new_dataset = self.validation_dataset.to_raw_input_dataset(processes=checkpoint_params.processes, progress_bar=progress_bar)
                exit_stack.pop(self.validation_dataset)
                self.validation_dataset = new_dataset
                exit_stack.enter_context(self.validation_dataset)

            # compute the codec
            if self.codec:
                codec = self.codec
            else:
                if len(self.codec_whitelist) == 0 or auto_compute_codec:
                    codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                                     whitelist=self.codec_whitelist, progress_bar=progress_bar)
                else:
                    codec = Codec.from_texts([], whitelist=self.codec_whitelist)

            # create backend
            network_params = checkpoint_params.model.network
            network_params.features = checkpoint_params.model.line_height
            if self.weights:
                # if we load the weights, take care of codec changes as well
                ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
                restore_checkpoint_params = ckpt.checkpoint
                restore_model_params = restore_checkpoint_params.model

                # checks
                if restore_model_params.line_height != network_params.features:
                    raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                        restore_model_params.line_height, checkpoint_params.model.line_height
                    ))

                # create codec of the same type
                restore_codec = codec.__class__(restore_model_params.codec.charset)

                # the codec changes are a tuple (deletions/insertions); the new codec is the changed old one
                codec_changes = restore_codec.align(codec, shrink=not self.keep_loaded_codec)
                codec = restore_codec
                print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
                # The actual weight/bias matrix will be changed after loading the old weights
                if not any(codec_changes):
                    codec_changes = None  # No codec changes
            else:
                codec_changes = None

            # store the new codec
            network_params.classes = len(codec)
            checkpoint_params.model.codec.charset[:] = codec.charset
            print("CODEC: {}".format(codec.charset))

            backend = create_backend_from_checkpoint(
                checkpoint_params=checkpoint_params,
                processes=checkpoint_params.processes,
            )
            train_net = backend.create_net(codec, graph_type="train",
                                           checkpoint_to_load=Checkpoint(self.weights) if self.weights else None,
                                           batch_size=checkpoint_params.batch_size, codec_changes=codec_changes)

            if checkpoint_params.current_stage == 0:
                self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset, training_callback)

            if checkpoint_params.data_aug_retrain_on_original and self.data_augmenter and self.n_augmentations != 0:
                print("Starting training on original data only")
                if checkpoint_params.current_stage == 0:
                    checkpoint_params.current_stage = 1
                    checkpoint_params.iter = 0
                    checkpoint_params.early_stopping_best_at_iter = 0
                    checkpoint_params.early_stopping_best_cur_nbest = 0
                    checkpoint_params.early_stopping_best_accuracy = 0

                self.dataset.generate_only_non_augmented = True  # this is the important line!
                self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset, training_callback)
Example no. 12
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Else provide a codec using a whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() + self.checkpoint_params.total_time

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        else:
            if len(self.codec_whitelist) == 0 or auto_compute_codec:
                codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                                 whitelist=self.codec_whitelist, progress_bar=progress_bar)
            else:
                codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != network_params.features:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes are a tuple (deletions/insertions); the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        train_net = backend.create_net(self.dataset, codec, restore=None, weights=self.weights, graph_type="train", batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset, codec, restore=None, weights=self.weights, graph_type="test", batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()   # to prevent blocking of tensorflow on shutdown