Exemplo n.º 1
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 backend=None):
        """Set up the instance from a checkpoint or from an existing backend.

        Exactly one of `checkpoint` and `backend` has to be given.

        Parameters
        ----------
        checkpoint : str, optional
            Path prefix of a checkpoint. The file `checkpoint + '.json'` is
            parsed as `CheckpointParams` and a backend is created from the
            network parameters stored in it.
        text_postproc : optional
            Text post processor. If falsy and a checkpoint is used, it is
            created from the model parameters of the checkpoint.
        data_preproc : optional
            Data pre processor. If falsy and a checkpoint is used, it is
            created from the model parameters of the checkpoint.
        codec : optional
            Stored on the instance without further processing.
        backend : optional
            Already created backend; its `network_proto` is used as the
            network parameters.

        Raises
        ------
        Exception
            If both or none of `checkpoint` and `backend` are provided.
        """
        self.codec = codec
        self.checkpoint = checkpoint
        self.backend = backend

        if not checkpoint:
            if not backend:
                raise Exception(
                    "Either a checkpoint or a existing backend must be provided")

            # existing backend: nothing is read from disk, processors are
            # taken exactly as passed by the caller
            self.model_params = None
            self.network_params = backend.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            return

        if backend:
            raise Exception(
                "Either a checkpoint or a backend can be provided")

        # read the model description that is stored next to the checkpoint
        with open(checkpoint + '.json', 'r') as params_file:
            serialized_params = params_file.read()

        self.model_params = json_format.Parse(serialized_params,
                                              CheckpointParams()).model
        self.network_params = self.model_params.network
        self.backend = create_backend_from_proto(self.network_params,
                                                 restore=self.checkpoint)

        # explicitly passed processors win over the ones of the checkpoint
        self.text_postproc = text_postproc or text_processor_from_proto(
            self.model_params.text_postprocessor, "post")
        self.data_preproc = data_preproc or data_processor_from_proto(
            self.model_params.data_preprocessor)
Exemplo n.º 2
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 network=None,
                 batch_size=1,
                 processes=1,
                 auto_update_checkpoints=True,
                 with_gt=False,
                 ):
        """ Predicting a dataset based on a trained model

        Exactly one of `checkpoint` and `network` has to be given.

        Parameters
        ----------
        checkpoint : str, optional
            Filepath of the checkpoint of the network to load. Alternatively, pass an already loaded `network`.
        text_postproc : TextProcessor, optional
            Text processor to apply to the predicted sentence for the final output.
            If omitted and a checkpoint is used, it is created from the checkpoint's model parameters.
        data_preproc : DataProcessor, optional
            Data processor (must be the same as of the trained model) to apply to the input image.
            If omitted and a checkpoint is used, it is created from the checkpoint's model parameters.
        codec : Codec, optional
            Codec of the deep net to use for decoding. Only required if a custom codec is used,
            or if a `network` has been provided instead of a `checkpoint`.
        network : ModelInterface, optional
            DNN instance to use. Alternatively, provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction
        processes : int, optional
            The number of processes to use for prediction
        auto_update_checkpoints : bool, optional
            Update old models automatically (this will change the checkpoint files)
        with_gt : bool, optional
            The prediction will also output the ground truth if available else None

        Raises
        ------
        Exception
            If both or none of `checkpoint` and `network` are provided, or if a `network`
            is provided without a `codec`.
        """
        self.with_gt = with_gt
        self.auto_update_checkpoints = auto_update_checkpoints
        self.processes = processes
        self.checkpoint = checkpoint
        self.network = network

        if not checkpoint and not network:
            raise Exception("Either a checkpoint or a existing backend must be provided")

        if checkpoint:
            if network:
                raise Exception("Either a checkpoint or a network can be provided")

            loaded = Checkpoint(checkpoint, auto_update=self.auto_update_checkpoints)
            # from here on, use the path as reported by the Checkpoint object
            self.checkpoint = loaded.ckpt_path
            self.model_params = loaded.checkpoint.model
            self.codec = codec or Codec(self.model_params.codec.charset)

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(
                self.network_params,
                restore=self.checkpoint,
                processes=processes,
            )
            # explicitly passed processors win over the ones of the checkpoint
            self.text_postproc = text_postproc or text_processor_from_proto(
                self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc or data_processor_from_proto(
                self.model_params.data_preprocessor)
            self.network = backend.create_net(
                dataset=None,
                codec=self.codec,
                restore=self.checkpoint,
                weights=None,
                graph_type="predict",
                batch_size=batch_size,
            )
        else:
            # preloaded network: everything is taken exactly as passed
            self.codec = codec
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            if not codec:
                raise Exception("A codec is required if preloaded network is used.")

        self.out_to_in_trans = OutputToInputTransformer(self.data_preproc, self.network)
Exemplo n.º 3
0
    def train(self, progress_bar=False):
        """ Launch the training

        Loads and preprocesses the training (and optional validation) data, computes or
        restores the codec, creates the backend and runs the training loop including
        periodic checkpoints and optional early stopping.

        Parameters
        ----------
        progress_bar : bool
            Show or hide any progress bar

        Raises
        ------
        Exception
            If the training dataset is empty, if a validation dataset is given but empty,
            if the line height of the model to restore differs from the requested one,
            or if the loss becomes non-finite before any checkpoint was written.
        KeyboardInterrupt
            Re-raised after an "interrupted" checkpoint has been stored.
        """
        checkpoint_params = self.checkpoint_params

        # The total time is stored as `time.time() - train_start_time`, so the time of
        # previous runs must move the start into the past (subtract, not add).
        train_start_time = time.time() - checkpoint_params.total_time

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
        else:
            validation_datas, validation_txts = [], []

        # preprocessing steps
        texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                             processes=checkpoint_params.processes, progress_bar=progress_bar)

            # TODO: validation data augmentation

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as-well
            with open(self.weights + '.json', 'r') as f:
                restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
                restore_model_params = restore_checkpoint_params.model

            # checks: compare against the restored model. `network_params.features` was
            # just assigned from the requested line height, so comparing those two
            # could never detect a mismatch.
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        backend.set_train_data(datas, labels)
        backend.set_prediction_data(validation_datas)
        if codec_changes:
            backend.realign_model_labels(*codec_changes)
        backend.prepare(train=True)

        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             backend=backend)

        # Start the actual training
        # ====================================================================================

        # named `cur_iter` to avoid shadowing the builtin `iter`
        cur_iter = checkpoint_params.iter

        # helper function to write a checkpoint
        # (reads the current iteration and early stopping state of the enclosing scope at call time)
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, cur_iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            backend.save_checkpoint(checkpoint_path)
            checkpoint_params.iter = cur_iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None

            # Training loop, can be interrupted by early stopping
            for cur_iter in range(cur_iter, checkpoint_params.max_iters):
                checkpoint_params.iter = cur_iter

                iter_start_time = time.time()
                result = backend.train_step(checkpoint_params.batch_size)

                if not np.isfinite(result['loss']):
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                        continue

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if cur_iter % checkpoint_params.display == 0:
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(cur_iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    print(" PRED: '{}'".format(pred_sentence))
                    print(" TRUE: '{}'".format(gt_sentence))

                if (cur_iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (cur_iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                               progress_bar=progress_bar, apply_preproc=False)
                    pred_texts = [d.sentence for d in out]
                    eval_result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                    accuracy = 1 - eval_result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = cur_iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

        except KeyboardInterrupt:
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            # bare raise keeps the original traceback
            raise

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, cur_iter))
Exemplo n.º 4
0
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Optionally preloads the datasets, determines the codec, creates a train and a
        test net and runs the training. If `data_aug_retrain_on_original` is set and
        data augmentation is active, a second training stage on the non-augmented data
        follows.

        Parameters
        ----------
        auto_compute_codec : bool
            Compute the codec automatically based on the provided ground truth.
            Else provide a codec using a whitelist (faster).

        progress_bar : bool
            Show or hide any progress bar

        Raises
        ------
        Exception
            If weights are restored and the line height of the restored model differs
            from the requested line height.
        """
        checkpoint_params = self.checkpoint_params

        # Time of previous runs must move the start into the past (subtract, not add).
        # NOTE(review): assumes `_run_train` derives the total time as
        # `time.time() - train_start_time` - confirm.
        train_start_time = time.time() - checkpoint_params.total_time

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes,
                                 progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(
                processes=checkpoint_params.processes,
                progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        elif len(self.codec_whitelist) == 0 or auto_compute_codec:
            codec = Codec.from_input_dataset(
                [self.dataset, self.validation_dataset],
                whitelist=self.codec_whitelist,
                progress_bar=progress_bar)
        else:
            codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as-well
            ckpt = Checkpoint(self.weights + '.json',
                              auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks: compare against the restored model. `network_params.features` was
            # just assigned from the requested line height, so comparing those two
            # could never detect a mismatch.
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception(
                    "The model to restore has a line height of {} but a line height of {} is requested"
                    .format(restore_model_params.line_height,
                            checkpoint_params.model.line_height))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(
                len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights.
            # Both entries are sized containers (see the len() calls above); comparing the
            # containers themselves with 0 does not test for emptiness, so check the lengths.
            if all(len(c) == 0 for c in codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(
            network_params,
            weights=self.weights,
        )
        train_net = backend.create_net(self.dataset,
                                       codec,
                                       restore=None,
                                       weights=self.weights,
                                       graph_type="train",
                                       batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset,
                                      codec,
                                      restore=None,
                                      weights=self.weights,
                                      graph_type="test",
                                      batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        # stage 0: training on the dataset as configured (skipped if a previous run
        # already advanced to stage 1)
        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        retrain_on_original = (checkpoint_params.data_aug_retrain_on_original
                               and self.dataset.data_augmenter
                               and self.dataset.data_augmentation_amount > 0)
        if retrain_on_original:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                # entering stage 1 for the first time: reset the training progress
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time,
                            progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()  # to prevent blocking of tensorflow on shutdown
Exemplo n.º 5
0
    def train(self, progress_bar=False):
        """ Launch the training

        Loads and preprocesses the training (and optional validation) data, determines
        the codec, optionally augments the training data, creates a train and a test net
        and runs the training. If `data_aug_retrain_on_original` is set and data
        augmentation is active, a second training stage on the original (non-augmented)
        data follows.

        Parameters
        ----------
        progress_bar : bool
            Show or hide any progress bar

        Raises
        ------
        Exception
            If the training dataset is empty, if a validation dataset is given but empty,
            or if weights are restored and the line height of the restored model differs
            from the requested line height.
        """
        checkpoint_params = self.checkpoint_params

        # Time of previous runs must move the start into the past (subtract, not add).
        # NOTE(review): assumes `_run_train` derives the total time as
        # `time.time() - train_start_time` - confirm.
        train_start_time = time.time() - checkpoint_params.total_time

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(
            skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception(
                "Empty dataset is not allowed. Check if the data is at the correct location"
            )

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1,
                                                 progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(
                skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception(
                    "Validation dataset is empty. Provide valid validation data for early stopping."
                )
        else:
            validation_datas, validation_txts = [], []

        # preprocessing steps
        texts = self.txt_preproc.apply(txts,
                                       processes=checkpoint_params.processes,
                                       progress_bar=progress_bar)
        # the data preprocessor yields pairs; only the first entry of each pair
        # (the data) is needed for training
        preprocessed = self.data_preproc.apply(
            datas,
            processes=checkpoint_params.processes,
            progress_bar=progress_bar)
        datas, _ = [list(column) for column in zip(*preprocessed)]
        validation_txts = self.txt_preproc.apply(
            validation_txts,
            processes=checkpoint_params.processes,
            progress_bar=progress_bar)
        # NOTE(review): `validation_datas` itself is not replaced by the preprocessed
        # output; the result is only handed to `_run_train` - confirm this is intended.
        validation_data_params = self.data_preproc.apply(
            validation_datas,
            processes=checkpoint_params.processes,
            progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(
            texts, whitelist=self.codec_whitelist)

        # store original data in case data augmentation is used with a second step
        original_texts = texts
        original_datas = datas

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(
                datas,
                texts,
                n_augmentations=self.n_augmentations,
                processes=checkpoint_params.processes,
                progress_bar=progress_bar)

            # TODO: validation data augmentation

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as-well
            ckpt = Checkpoint(self.weights + '.json',
                              auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks: compare against the restored model. `network_params.features` was
            # just assigned from the requested line height, so comparing those two
            # could never detect a mismatch.
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception(
                    "The model to restore has a line height of {} but a line height of {} is requested"
                    .format(restore_model_params.line_height,
                            checkpoint_params.model.line_height))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(
                len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights.
            # Both entries are sized containers (see the len() calls above); comparing the
            # containers themselves with 0 does not test for emptiness, so check the lengths.
            if all(len(c) == 0 for c in codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(
            network_params,
            weights=self.weights,
        )
        train_net = backend.create_net(restore=None,
                                       weights=self.weights,
                                       graph_type="train",
                                       batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(restore=None,
                                      weights=self.weights,
                                      graph_type="test",
                                      batch_size=checkpoint_params.batch_size)
        train_net.set_data(datas, labels)
        test_net.set_data(validation_datas, validation_txts)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        # stage 0: training on the (possibly augmented) data (skipped if a previous
        # run already advanced to stage 1)
        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, validation_data_params,
                            train_start_time, progress_bar)

        retrain_on_original = (checkpoint_params.data_aug_retrain_on_original
                               and self.data_augmenter
                               and self.n_augmentations > 0)
        if retrain_on_original:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                # entering stage 1 for the first time: reset the training progress
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            train_net.set_data(original_datas,
                               [codec.encode(txt) for txt in original_texts])
            test_net.set_data(validation_datas, validation_txts)
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, validation_data_params,
                            train_start_time, progress_bar)
Exemplo n.º 6
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 network=None,
                 batch_size=1,
                 processes=1,
                 auto_update_checkpoints=True,
                 with_gt=False,
                 ):
        """ Predicting a dataset based on a trained model

        Exactly one of `checkpoint` and `network` has to be given.

        Parameters
        ----------
        checkpoint : str, optional
            Filepath of the checkpoint of the network to load. Alternatively, pass an already loaded `network`.
        text_postproc : TextProcessor, optional
            Text processor to apply to the predicted sentence for the final output.
            If omitted and a checkpoint is used, it is created from the checkpoint's model parameters.
        data_preproc : DataProcessor, optional
            Data processor (must be the same as of the trained model) to apply to the input image.
            If omitted and a checkpoint is used, it is created from the checkpoint's model parameters.
        codec : Codec, optional
            Codec of the deep net to use for decoding. Only required if a custom codec is used,
            or if a `network` has been provided instead of a `checkpoint`.
        network : ModelInterface, optional
            DNN instance to use. Alternatively, provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction
        processes : int, optional
            The number of processes to use for prediction
        auto_update_checkpoints : bool, optional
            Update old models automatically (this will change the checkpoint files)
        with_gt : bool, optional
            The prediction will also output the ground truth if available else None

        Raises
        ------
        Exception
            If both or none of `checkpoint` and `network` are provided, or if a `network`
            is provided without a `codec`.
        """
        self.with_gt = with_gt
        self.auto_update_checkpoints = auto_update_checkpoints
        self.processes = processes
        self.checkpoint = checkpoint
        self.network = network

        if not checkpoint and not network:
            raise Exception("Either a checkpoint or a existing backend must be provided")

        if checkpoint:
            if network:
                raise Exception("Either a checkpoint or a network can be provided")

            loaded = Checkpoint(checkpoint, auto_update=self.auto_update_checkpoints)
            self.model_params = loaded.checkpoint.model
            self.codec = codec or Codec(self.model_params.codec.charset)

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(
                self.network_params,
                restore=self.checkpoint,
                processes=processes,
            )
            # explicitly passed processors win over the ones of the checkpoint
            self.text_postproc = text_postproc or text_processor_from_proto(
                self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc or data_processor_from_proto(
                self.model_params.data_preprocessor)
            self.network = backend.create_net(
                dataset=None,
                codec=self.codec,
                restore=self.checkpoint,
                weights=None,
                graph_type="predict",
                batch_size=batch_size,
            )
        else:
            # preloaded network: everything is taken exactly as passed
            self.codec = codec
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            if not codec:
                raise Exception("A codec is required if preloaded network is used.")

        self.out_to_in_trans = OutputToInputTransformer(self.data_preproc, self.network)
Exemplo n.º 7
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 network=None,
                 batch_size=1,
                 processes=1):
        """ Predicting a dataset based on a trained model

        Exactly one of `checkpoint` and `network` must be provided.

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load, alternatively you can directly use a loaded `network`.
            The model description is read from the file `checkpoint + '.json'`.
        text_postproc : TextProcessor, optional
            text processor to be applied on the predicted sentence for the final output.
            If omitted and a `checkpoint` is used, the text processor is created from the checkpoint.
        data_preproc : DataProcessor, optional
            data processor (must be the same as of the trained model) to be applied to the input image.
            If omitted and a `checkpoint` is used, the data processor is created from the checkpoint.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to use. Alternatively you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction (only used if the network is created from a `checkpoint`)
        processes : int, optional
            The number of processes to use for prediction

        Raises
        ------
        Exception
            If both or neither of `checkpoint` and `network` are provided,
            or if a `network` is provided without a `codec`.
        """
        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes

        if checkpoint:
            if network:
                raise Exception(
                    "Either a checkpoint or a network can be provided")

            # the model parameters are stored as json next to the checkpoint
            with open(checkpoint + '.json', 'r') as f:
                checkpoint_params = json_format.Parse(f.read(),
                                                      CheckpointParams())
            self.model_params = checkpoint_params.model
            self.network_params = self.model_params.network

            # build the prediction graph and restore the trained weights
            backend = create_backend_from_proto(self.network_params,
                                                restore=self.checkpoint,
                                                processes=processes)
            self.network = backend.create_net(restore=self.checkpoint,
                                              weights=None,
                                              graph_type="predict",
                                              batch_size=batch_size)

            # explicitly passed processors take precedence over the ones stored in the checkpoint
            self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(
                self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
                self.model_params.data_preprocessor)
            self.codec = codec if codec else Codec(self.model_params.codec.charset)
        elif network:
            # without a checkpoint there are no model params to derive a codec from
            if not codec:
                raise Exception(
                    "A codec is required if preloaded network is used.")

            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            self.codec = codec
        else:
            raise Exception(
                "Either a checkpoint or an existing network must be provided")

        self.out_to_in_trans = OutputToInputTransformer(
            self.data_preproc, self.network)
Exemplo n.º 8
0
    def train(self, auto_compute_codec=False, progress_bar=False):
        """ Launch the training

        Determines the codec, creates the train and test networks, optionally restores
        pretrained weights (adapting them to the codec) and runs the training stage(s).

        Parameters
        ----------
        auto_compute_codec : bool
            If no explicit codec is set, compute the codec from the training and validation datasets
            even if a codec whitelist is given. Without this flag a non-empty whitelist alone
            defines the codec.
        progress_bar : bool
            Show or hide any progress bar

        Raises
        ------
        Exception
            If weights are restored from a model whose line height differs from the requested one.
        """
        checkpoint_params = self.checkpoint_params

        # NOTE(review): the already elapsed total_time is *added* to the current time here; how this
        # value is consumed is not visible in this method - confirm the sign against _run_train.
        train_start_time = time.time() + checkpoint_params.total_time

        # load training dataset
        if self.preload_training:
            self.dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            self.validation_dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        if self.codec:
            codec = self.codec
        elif len(self.codec_whitelist) == 0 or auto_compute_codec:
            # derive the characters from the data (the whitelist is passed along)
            codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                             whitelist=self.codec_whitelist, progress_bar=progress_bar)
        else:
            # the whitelist alone defines the codec
            codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        codec_changes = None
        if self.weights:
            # if we load the weights, take care of codec changes as-well
            ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks: compare against the *restored* model, network_params.features was just
            # set to the requested line height and would therefore always match
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
            if all(len(c) == 0 for c in codec_changes):
                codec_changes = None  # No codec changes

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        train_net = backend.create_net(self.dataset, codec, restore=None, weights=self.weights,
                                       graph_type="train", batch_size=checkpoint_params.batch_size)
        test_net = backend.create_net(self.validation_dataset, codec, restore=None, weights=self.weights,
                                      graph_type="test", batch_size=checkpoint_params.batch_size)
        if codec_changes:
            # only required on one net, since the other shares the same variables
            train_net.realign_model_labels(*codec_changes)

        train_net.prepare()
        test_net.prepare()

        # stage 0: regular training (skipped if the checkpoint is already in a later stage)
        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        # stage 1: optional retraining on the non-augmented data only
        if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                # entering the new stage: reset the iteration and early stopping counters
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            train_net.prepare()
            test_net.prepare()
            self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

        train_net.prepare()  # reset the state
        test_net.prepare()   # to prevent blocking of tensorflow on shutdown