Example #1
    def train(self, progress_bar=False):
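        """ Launch the training

        Parameters
        ----------
        progress_bar : bool
            Show or hide any progress bar

        """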
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() - self.checkpoint_params.total_time  # subtract previously accumulated time so total_time keeps accumulating across resumes

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
        else:
            validation_datas, validation_txts = [], []


        # preprocessing steps
        texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                             processes=checkpoint_params.processes, progress_bar=progress_bar)

            # TODO: validation data augmentation
            # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
            #                                                  processes=checkpoint_params.processes, progress_bar=progress_bar)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            with open(self.weights + '.json', 'r') as f:
                restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
                restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # align() returns the codec changes as a tuple (deletions, insertions); the new codec is the old one with those changes applied
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        backend.set_train_data(datas, labels)
        backend.set_prediction_data(validation_datas)
        if codec_changes:
            backend.realign_model_labels(*codec_changes)
        backend.prepare(train=True)

        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             backend=backend)

        # Start the actual training
        # ====================================================================================

        iter = checkpoint_params.iter

        # helper function to write a checkpoint
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            backend.save_checkpoint(checkpoint_path)
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = backend.train_step(checkpoint_params.batch_size)

                if not np.isfinite(result['loss']):
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                        continue

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if iter % checkpoint_params.display == 0:
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    print(" PRED: '{}'".format(pred_sentence))
                    print(" TRUE: '{}'".format(gt_sentence))

                if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                               progress_bar=progress_bar, apply_preproc=False)
                    pred_texts = [d.sentence for d in out]
                    result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

        except KeyboardInterrupt as e:
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise e

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
Example #2
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Else provide a codec using a whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() - self.checkpoint_params.total_time  # subtract previously accumulated time so total_time keeps accumulating across resumes

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes,
                                 progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(
                processes=checkpoint_params.processes,
                progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        else:
            if len(self.codec_whitelist) == 0 or auto_compute_codec:
                codec = Codec.from_input_dataset(
                    [self.dataset, self.validation_dataset],
                    whitelist=self.codec_whitelist,
                    progress_bar=progress_bar)
            else:
                codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json',
                              auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception(
                    "The model to restore has a line height of {} but a line height of {} is requested"
                    .format(restore_model_params.line_height,
                            checkpoint_params.model.line_height))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # align() returns the codec changes as a tuple (deletions, insertions); the new codec is the old one with those changes applied
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(
                len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            if all(len(c) == 0 for c in codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(
            network_params,
            weights=self.weights,
        )
        train_net = backend.create_net(self.dataset,
                                       codec,
                                       restore=None,
                                       weights=self.weights,
                                       graph_type="train",
                                       batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset,
                                      codec,
                                      restore=None,
                                      weights=self.weights,
                                      graph_type="test",
                                      batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()  # to prevent blocking of tensorflow on shutdown
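
The codec built above (from_texts or from_input_dataset) is, at its core, a character/label mapping. A minimal sketch of the from_texts idea, assuming the codec is a sorted charset plus encode/decode tables; the real codec additionally reserves a blank label for CTC, which this sketch omits:

class Codec:
    """Sketch: bidirectional mapping between characters and integer labels."""

    def __init__(self, charset):
        self.charset = list(charset)
        self._char_to_label = {c: i for i, c in enumerate(self.charset)}

    def __len__(self):
        return len(self.charset)

    def encode(self, text):
        return [self._char_to_label[c] for c in text]

    def decode(self, labels):
        return [self.charset[label] for label in labels]

    @staticmethod
    def from_texts(texts, whitelist=None):
        # union of all ground-truth characters plus the whitelist
        chars = set(whitelist or [])
        for text in texts:
            chars.update(text)
        return Codec(sorted(chars))
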
Example #3
    def train(self, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() - self.checkpoint_params.total_time  # subtract previously accumulated time so total_time keeps accumulating across resumes

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(
            skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception(
                "Empty dataset is not allowed. Check if the data is at the correct location"
            )

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1,
                                                 progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(
                skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception(
                    "Validation dataset is empty. Provide valid validation data for early stopping."
                )
        else:
            validation_datas, validation_txts = [], []

        # preprocessing steps
        texts = self.txt_preproc.apply(txts,
                                       processes=checkpoint_params.processes,
                                       progress_bar=progress_bar)
        datas, params = [
            list(a) for a in zip(
                *self.data_preproc.apply(datas,
                                         processes=checkpoint_params.processes,
                                         progress_bar=progress_bar))
        ]
        validation_txts = self.txt_preproc.apply(
            validation_txts,
            processes=checkpoint_params.processes,
            progress_bar=progress_bar)
        validation_data_params = self.data_preproc.apply(
            validation_datas,
            processes=checkpoint_params.processes,
            progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(
            texts, whitelist=self.codec_whitelist)

        # store original data in case data augmentation is used with a second step
        original_texts = texts
        original_datas = datas

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(
                datas,
                texts,
                n_augmentations=self.n_augmentations,
                processes=checkpoint_params.processes,
                progress_bar=progress_bar)

            # TODO: validation data augmentation
            # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
            #                                                  processes=checkpoint_params.processes, progress_bar=progress_bar)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json',
                              auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception(
                    "The model to restore has a line height of {} but a line height of {} is requested"
                    .format(restore_model_params.line_height,
                            checkpoint_params.model.line_height))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # align() returns the codec changes as a tuple (deletions, insertions); the new codec is the old one with those changes applied
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(
                len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            if all(len(c) == 0 for c in codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(
            network_params,
            weights=self.weights,
        )
        train_net = backend.create_net(restore=None,
                                       weights=self.weights,
                                       graph_type="train",
                                       batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(restore=None,
                                      weights=self.weights,
                                      graph_type="test",
                                      batch_size=checkpoint_params.batch_size)
        train_net.set_data(datas, labels)
        test_net.set_data(validation_datas, validation_txts)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, validation_data_params,
                            train_start_time, progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.data_augmenter and self.n_augmentations > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            train_net.set_data(original_datas,
                               [codec.encode(txt) for txt in original_texts])
            test_net.set_data(validation_datas, validation_txts)
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, validation_data_params,
                            train_start_time, progress_bar)
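
restore_codec.align(codec) is the pivot of warm-starting in every variant: it reports which labels of the loaded codec must be deleted and which characters appended, and turns the loaded codec into the adapted one. A hedged sketch of that alignment on plain charsets; the real method's exact return types and shrink semantics may differ:

def align(old_charset, new_charset, shrink=True):
    """Sketch of Codec.align: adapt the loaded charset to the new one.
    Returns (deleted_labels, appended_labels)."""
    new_set = set(new_charset)
    # labels of loaded characters that no longer occur in the new data
    deleted = {label for label, c in enumerate(old_charset) if c not in new_set} if shrink else set()
    kept = [c for label, c in enumerate(old_charset) if label not in deleted]
    appended = []
    for c in new_charset:
        if c not in kept:
            appended.append(len(kept))  # label assigned to the new character
            kept.append(c)
    old_charset[:] = kept  # the loaded codec becomes the adapted one
    return sorted(deleted), appended
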
Example #4
    def train(self, callbacks=None, **kwargs):
        callbacks = callbacks if callbacks else []

        # load preloaded dataset
        data: Data = self._data
        model_params: ModelParams = self.scenario.params.model_params

        train_pipeline = data.get_train_data()
        if len(train_pipeline.create_data_generator()) == 0:
            raise ValueError("Training dataset is empty.")

        if data.params().val:
            val_pipeline = data.get_val_data()
            if len(val_pipeline.create_data_generator()) == 0:
                raise ValueError(
                    "Validation dataset is empty. Provide valid validation data for early stopping."
                )
        else:
            val_pipeline = None

        if self._params.preload_training:
            # preload before the codec is created (not all processors can be applied yet)
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = data.get_train_data()
            if val_pipeline:
                val_pipeline = data.get_val_data()

        # compute the codec
        codec = data.params().codec
        if not codec:
            if len(self._params.codec_whitelist) == 0 or self._params.auto_compute_codec:
                codec = Codec.from_input_dataset(
                    filter(lambda x: x, [train_pipeline, val_pipeline]),
                    whitelist=self._params.codec_whitelist,
                    progress_bar=self._params.progress_bar)
            else:
                codec = Codec.from_texts(
                    [], whitelist=self._params.codec_whitelist)

        data.params().codec = codec
        data.params().downscale_factor_ = model_params.compute_downscale_factor()
        model_params.classes = codec.size()

        if self.checkpoint:
            # if we load the weights, take care of codec changes as well
            restore_checkpoint_params = self.checkpoint.dict
            restore_data_params = restore_checkpoint_params['scenario_params'][
                'data_params']

            # checks
            if data.params().line_height_ != restore_data_params['line_height_']:
                raise ValueError(
                    f"The model to restore has a line height of {restore_data_params['line_height_']}"
                    f" but a line height of {data.params().line_height_} is requested"
                )

            # create codec of the same type
            restore_codec = codec.__class__(
                restore_data_params['codec']['charset'])

            # align() returns the codec changes as a tuple (deletions, insertions); the new codec is the old one with those changes applied
            codec_changes = restore_codec.align(
                codec, shrink=not self._params.keep_loaded_codec)
            codec = restore_codec
            logger.info(
                f"Codec changes: {len(codec_changes[0])} deletions, {len(codec_changes[1])} appends"
            )
            # The actual weight/bias matrix will be changed after loading the old weights
            if not any(codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        model_params.classes = codec.size()
        data.params().codec = codec
        logger.info(f"CODEC: {codec.charset}")

        if self._params.preload_training:
            # preload after codec was created
            data.preload(progress_bar=self._params.progress_bar)
            train_pipeline = data.get_train_data()

        if self._params.current_stage == 0:
            super(Trainer, self).train(
                callbacks=callbacks,
                warmstart_fn=partial(WarmstarterWithCodecAdaption,
                                     codec_changes=codec_changes),
            )

        if self._params.data_aug_retrain_on_original and self._params.scenario_params.data_params.data_aug_params.to_abs() > 0:
            logger.info("Starting training on original data only")
            if self._params.current_stage == 0:
                self._params.current_epoch = 0
                self._params.current_stage = 1
                self._params.early_stopping_params.current_ = 1  # CER = 100% as initial value
                self._params.early_stopping_params.n_ = 0

            # Remove data augmenter
            self._data.params().pre_processors_.sample_processors = [
                p
                for p in self._data.params().pre_processors_.sample_processors
                if p.name != AugmentationProcessor.__name__
            ]
            # Remove augmented samples if 'preloaded"
            if isinstance(train_pipeline, RawDataPipeline):
                train_pipeline.samples = [
                    s for s in train_pipeline.samples
                    if not s.meta.get('augmented', False)
                ]

            logger.info(
                f"Training on {len(train_pipeline.create_data_generator())} samples."
            )

            super(Trainer, self).setup_steps_per_epoch()

            # replace callbacks that require steps per epoch as parameter
            tb_callback = next(cb for cb in self._callbacks
                               if isinstance(cb, TensorBoardCallback))
            tb_callback.steps_per_epoch = self._steps_per_epoch
            i = next(i for i, cb in enumerate(self._callbacks)
                     if isinstance(cb, TrainParamsLoggerCallback))
            del self._callbacks[i]
            self._callbacks.insert(i,
                                   self.create_train_params_logger_callback())

            super(Trainer, self).fit()

        logger.info("Training finished")
Example #5
    def train(self, auto_compute_codec=False, progress_bar=False, training_callback=ConsoleTrainingCallback()):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Else provide a codec using a whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        training_callback : TrainingCallback
            Callback for the training process (e.g., for displaying the current cer, loss in the console)

        """
        with ExitStackWithPop() as exit_stack:
            checkpoint_params = self.checkpoint_params

            train_start_time = time.time() - self.checkpoint_params.total_time  # subtract previously accumulated time so total_time keeps accumulating across resumes

            exit_stack.enter_context(self.dataset)
            if self.validation_dataset:
                exit_stack.enter_context(self.validation_dataset)

            # load training dataset
            if self.preload_training:
                new_dataset = self.dataset.to_raw_input_dataset(processes=checkpoint_params.processes, progress_bar=progress_bar)
                exit_stack.pop(self.dataset)
                self.dataset = new_dataset
                exit_stack.enter_context(self.dataset)

            # load validation dataset
            if self.validation_dataset and self.preload_validation:
                new_dataset = self.validation_dataset.to_raw_input_dataset(processes=checkpoint_params.processes, progress_bar=progress_bar)
                exit_stack.pop(self.validation_dataset)
                self.validation_dataset = new_dataset
                exit_stack.enter_context(self.validation_dataset)

            # compute the codec
            if self.codec:
                codec = self.codec
            else:
                if len(self.codec_whitelist) == 0 or auto_compute_codec:
                    codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                                     whitelist=self.codec_whitelist, progress_bar=progress_bar)
                else:
                    codec = Codec.from_texts([], whitelist=self.codec_whitelist)

            # create backend
            network_params = checkpoint_params.model.network
            network_params.features = checkpoint_params.model.line_height
            if self.weights:
                # if we load the weights, take care of codec changes as well
                ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
                restore_checkpoint_params = ckpt.checkpoint
                restore_model_params = restore_checkpoint_params.model

                # checks
                if restore_model_params.line_height != checkpoint_params.model.line_height:
                    raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                        restore_model_params.line_height, checkpoint_params.model.line_height
                    ))

                # create codec of the same type
                restore_codec = codec.__class__(restore_model_params.codec.charset)

                # align() returns the codec changes as a tuple (deletions, insertions); the new codec is the old one with those changes applied
                codec_changes = restore_codec.align(codec, shrink=not self.keep_loaded_codec)
                codec = restore_codec
                print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
                # The actual weight/bias matrix will be changed after loading the old weights
                if all(len(c) == 0 for c in codec_changes):
                    codec_changes = None  # No codec changes
            else:
                codec_changes = None

            # store the new codec
            network_params.classes = len(codec)
            checkpoint_params.model.codec.charset[:] = codec.charset
            print("CODEC: {}".format(codec.charset))

            backend = create_backend_from_checkpoint(
                checkpoint_params=checkpoint_params,
                processes=checkpoint_params.processes,
            )
            train_net = backend.create_net(codec, graph_type="train",
                                           checkpoint_to_load=Checkpoint(self.weights) if self.weights else None,
                                           batch_size=checkpoint_params.batch_size, codec_changes=codec_changes)

            if checkpoint_params.current_stage == 0:
                self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset, training_callback)

            if checkpoint_params.data_aug_retrain_on_original and self.data_augmenter and self.n_augmentations != 0:
                print("Starting training on original data only")
                if checkpoint_params.current_stage == 0:
                    checkpoint_params.current_stage = 1
                    checkpoint_params.iter = 0
                    checkpoint_params.early_stopping_best_at_iter = 0
                    checkpoint_params.early_stopping_best_cur_nbest = 0
                    checkpoint_params.early_stopping_best_accuracy = 0

                self.dataset.generate_only_non_augmented = True  # this is the important line!
                self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset, training_callback)
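
ExitStackWithPop is not a standard-library class; it is used above like contextlib.ExitStack, except that one context manager can be popped and closed early so a preloaded dataset can replace the raw one inside the same stack. A minimal sketch with those assumed semantics:

class ExitStackWithPop:
    """Sketch: like contextlib.ExitStack, plus pop(cm) to exit one
    specific context manager ahead of the others."""

    def __init__(self):
        self._cms = []

    def __enter__(self):
        return self

    def enter_context(self, cm):
        result = cm.__enter__()
        self._cms.append(cm)
        return result

    def pop(self, cm):
        self._cms.remove(cm)
        cm.__exit__(None, None, None)

    def __exit__(self, exc_type, exc, tb):
        # close the remaining contexts in reverse order of entry
        for cm in reversed(self._cms):
            cm.__exit__(exc_type, exc, tb)
        self._cms.clear()
        return False
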
Example #6
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Else provide a codec using a whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        """
        checkpoint_params = self.checkpoint_params

        train_start_time = time.time() - self.checkpoint_params.total_time  # subtract previously accumulated time so total_time keeps accumulating across resumes

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        else:
            if len(self.codec_whitelist) == 0 or auto_compute_codec:
                codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                                 whitelist=self.codec_whitelist, progress_bar=progress_bar)
            else:
                codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as well
            ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # align() returns the codec changes as a tuple (deletions, insertions); the new codec is the old one with those changes applied
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            if all(len(c) == 0 for c in codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        train_net = backend.create_net(self.dataset, codec, restore=None, weights=self.weights, graph_type="train", batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset, codec, restore=None, weights=self.weights, graph_type="test", batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()   # to prevent blocking of tensorflow on shutdown
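
All variants share the same nbest early-stopping bookkeeping: remember the best validation accuracy, reset a counter when it improves, and stop once nbest consecutive checks fail to beat it. A condensed, self-contained sketch of that rule:

class NBestEarlyStopping:
    """Sketch of the early-stopping rule used above: stop after `nbest`
    consecutive validation checks without a new best accuracy."""

    def __init__(self, nbest, best_accuracy=0.0):
        self.nbest = nbest
        self.best_accuracy = best_accuracy
        self.best_at_iter = 0
        self.cur_nbest = 0

    def update(self, accuracy, iteration):
        """Record one validation result; returns True if training should stop."""
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_at_iter = iteration
            self.cur_nbest = 1  # the new best itself is the first of the nbest
        else:
            self.cur_nbest += 1
        return accuracy > 0 and self.cur_nbest >= self.nbest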