Exemple #1
0
 def __init__(self,
              checkpoint_params,
              dataset,
              validation_dataset=None,
              txt_preproc=None,
              txt_postproc=None,
              data_preproc=None,
              data_augmenter=None,
              n_augmentations=0,
              weights=None,
              codec=None,
              codec_whitelist=None):
     """Store the training setup and resolve default processors.

     Any of ``txt_preproc``, ``txt_postproc`` and ``data_preproc`` that is not
     given (falsy) is built from the corresponding proto in
     ``checkpoint_params.model``.

     Parameters
     ----------
     checkpoint_params : checkpoint parameter proto; its ``.model`` is read here
     dataset : dataset used for training (stored as-is)
     validation_dataset : optional validation dataset (stored as-is)
     txt_preproc : optional text preprocessor
     txt_postproc : optional text postprocessor
     data_preproc : optional data (image) preprocessor
     data_augmenter : optional data augmenter (stored as-is)
     n_augmentations : number of augmentations (stored as-is)
     weights : optional weights path, resolved with ``checkpoint_path``
     codec : optional codec (stored as-is)
     codec_whitelist : optional list of characters; ``None`` (the default)
         means an empty list
     """
     self.checkpoint_params = checkpoint_params
     self.dataset = dataset
     self.validation_dataset = validation_dataset
     self.data_augmenter = data_augmenter
     self.n_augmentations = n_augmentations
     self.txt_preproc = txt_preproc if txt_preproc else text_processor_from_proto(
         checkpoint_params.model.text_preprocessor, "pre")
     self.txt_postproc = txt_postproc if txt_postproc else text_processor_from_proto(
         checkpoint_params.model.text_postprocessor, "post")
     self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
         checkpoint_params.model.data_preprocessor)
     self.weights = checkpoint_path(weights) if weights else None
     self.codec = codec
     # A new list per instance: a ``[]`` default would be one object shared by
     # every instance created without an explicit whitelist.
     self.codec_whitelist = [] if codec_whitelist is None else codec_whitelist
Exemple #2
0
    def __init__(self, baseurl, cachefile, login, password=None):
        """ Create a nashi client
        Parameters
        ----------
        baseurl : web address of the nashi instance
        cachefile : filename of the hdf5 cache
        login : user name for nashi
        password : asks for user input if empty
        """
        # Connection and cache state (set before logging in).
        self.baseurl = baseurl
        self.session = None
        self.traindata = None
        self.recogdata = None
        self.valdata = None
        self.bookcache = {}
        self.cachefile = cachefile
        self.login(login, password)

        # Line image preprocessing: range normalization, centering, final preparation.
        line_params = DataPreprocessorParams()
        line_params.line_height = 48
        line_params.pad = 16
        line_params.pad_value = 1
        line_params.no_invert = False
        line_params.no_transpose = False
        self.data_proc = MultiDataProcessor([
            DataRangeNormalizer(),
            CenterNormalizer(line_params),
            FinalPreparation(line_params, as_uint8=True),
        ])

        def make_text_params():
            # Multi normalizer chain: NFC normalization -> "extended" regularization -> strip.
            text_params = TextProcessorParams()
            text_params.type = TextProcessorParams.MULTI_NORMALIZER
            default_text_normalizer_params(text_params.children.add(), default="NFC")
            default_text_regularizer_params(text_params.children.add(), groups=["extended"])
            text_params.children.add().type = TextProcessorParams.STRIP_NORMALIZER
            return text_params

        def add_bidi(text_params, direction):
            # Append a BIDI normalizer with the given direction to an existing chain.
            bidi_params = text_params.children.add()
            bidi_params.type = TextProcessorParams.BIDI_NORMALIZER
            bidi_params.bidi_direction = direction

        # Text pre processing (reading) and post processing (prediction).
        pre_params = make_text_params()
        self.txt_preproc = text_processor_from_proto(pre_params, "pre")
        post_params = make_text_params()
        self.text_postproc = text_processor_from_proto(post_params, "post")

        # BIDI variants: the same chains extended by a BIDI normalizer.
        add_bidi(pre_params, TextProcessorParams.BIDI_RTL)
        self.bidi_preproc = text_processor_from_proto(pre_params, "pre")
        add_bidi(post_params, TextProcessorParams.BIDI_AUTO)
        self.bidi_postproc = text_processor_from_proto(post_params, "post")
Exemple #3
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 backend=None):
        """Set up prediction from either a checkpoint or an existing backend.

        With a ``checkpoint``, the model parameters are parsed from
        ``<checkpoint>.json``, a backend is created from the network proto and
        missing processors are built from the model protos.  With a ``backend``,
        the given processors are used as-is and ``model_params`` is None.

        Raises
        ------
        Exception
            If both or neither of ``checkpoint`` and ``backend`` are given.
        """
        self.backend = backend
        self.checkpoint = checkpoint
        self.codec = codec

        if checkpoint and backend:
            raise Exception(
                "Either a checkpoint or a backend can be provided")
        if not (checkpoint or backend):
            raise Exception(
                "Either a checkpoint or a existing backend must be provided")

        if not checkpoint:
            # Pre-built backend: processors are taken exactly as passed in.
            self.model_params = None
            self.network_params = backend.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            return

        with open(checkpoint + '.json', 'r') as f:
            parsed = json_format.Parse(f.read(), CheckpointParams())
        self.model_params = parsed.model

        self.network_params = self.model_params.network
        self.backend = create_backend_from_proto(self.network_params,
                                                 restore=self.checkpoint)
        self.text_postproc = text_postproc or text_processor_from_proto(
            self.model_params.text_postprocessor, "post")
        self.data_preproc = data_preproc or data_processor_from_proto(
            self.model_params.data_preprocessor)
def main():
    """Apply the default text preprocessing to text files in place.

    Each resolved file is read as UTF-8 and passed as a whole through a
    processor chain (unicode normalization -> regularization -> strip).
    Unless ``--dry_run`` is given, the result is written back to the same path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files",
                        type=str,
                        nargs="+",
                        required=True,
                        help="Text files to apply text processing")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print the processed content of every file")
    parser.add_argument("--dry_run",
                        action="store_true",
                        help="Do not overwrite files, just run")

    args = parser.parse_args()

    # Text pre processing (reading)
    preproc = TextProcessorParams()
    preproc.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(preproc.children.add(),
                                   default=args.text_normalization)
    default_text_regularizer_params(preproc.children.add(),
                                    groups=args.text_regularization)
    strip_processor_params = preproc.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    txt_proc = text_processor_from_proto(preproc, "pre")

    print("Resolving files")
    text_files = glob_all(args.files)

    # tqdm takes the length from the list itself.
    for path in tqdm(text_files, desc="Processing"):
        with codecs.open(path, "r", "utf-8") as f:
            content = f.read()

        # The processor is applied to the complete file content at once.
        content = txt_proc.apply(content)

        if args.verbose:
            print(content)

        if not args.dry_run:
            with codecs.open(path, "w", "utf-8") as f:
                f.write(content)
Exemple #5
0
def main():
    """Predict evaluation lines with several checkpoints, vote and evaluate.

    Every checkpoint predicts the dataset built from ``--eval_imgs``.  Each
    voter in ``--voter`` combines the individual predictions.  Then every
    single model and every voter is evaluated against the ground truth
    (``<base>.gt.txt`` next to each image).  Results are optionally printed
    (``--verbose``) and pickled (``--dump``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs",
                        type=str,
                        nargs="+",
                        required=True,
                        help="The evaluation files")
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument(
        "--voter",
        type=str,
        nargs="+",
        default=[
            "sequence_voter", "confidence_voter_default_ctc",
            "confidence_voter_fuzzy_ctc"
        ],
        help=
        "The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
        "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size",
                        type=int,
                        default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump",
                        type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp)
                       for cp in args.checkpoint]

    # load files (resolve the glob once; the gt of "<base>.<ext>" is "<base>.gt.txt")
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = FileDataSet(images=gt_images,
                          texts=gt_txts,
                          skip_invalid=not args.no_skip_invalid_gt)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # The parser defines --eval_imgs; there is no ``args.files``.
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter_name in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter_name.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(
                voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(
        predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        # Evaluate one list of predicted sentences against the preloaded gt.
        if len(predicted_sentences) != len(dataset):
            raise Exception(
                "Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = RawDataSet(texts=predicted_sentences)

        return evaluator.run(pred_dataset=pred_data_set,
                             progress_bar=True,
                             processes=args.processes)

    # Keys: model index as string for single models, voter name for voters.
    runs = [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)]
    runs += list(zip(args.voter, all_voter_sentences))

    full_evaluation = {}
    for key, data in runs:
        full_evaluation[key] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump(
                {
                    "full": full_evaluation,
                    "gt_txts": gt_txts,
                    "gt": dataset.text_samples()
                }, f)
Exemple #6
0
    def __init__(self, checkpoint=None, text_postproc=None, data_preproc=None, codec=None, network=None,
                 batch_size=1, processes=1,
                 auto_update_checkpoints=True,
                 with_gt=False,
                 ):
        """ Predicting a dataset based on a trained model

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load, alternatively you can directly use a loaded `network`
        text_postproc : TextProcessor, optional
            text processor to be applied on the predicted sentence for the final output.
            If not given and a `checkpoint` is loaded, it is built from the checkpoint's model parameters.
        data_preproc : DataProcessor, optional
            data processor (must be the same as of the trained model) to be applied to the input image.
            If not given and a `checkpoint` is loaded, it is built from the checkpoint's model parameters.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to use. Alternatively you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction
        processes : int, optional
            The number of processes to use for prediction
        auto_update_checkpoints : bool, optional
            Update old models automatically (this will change the checkpoint files)
        with_gt : bool, optional
            The prediction will also output the ground truth if available else None

        Raises
        ------
        Exception
            If both `checkpoint` and `network` are given, if neither is given,
            or if a `network` is given without a `codec`.
        """
        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes
        self.auto_update_checkpoints = auto_update_checkpoints
        self.with_gt = with_gt

        if checkpoint:
            if network:
                raise Exception("Either a checkpoint or a network can be provided")

            ckpt = Checkpoint(checkpoint, auto_update=self.auto_update_checkpoints)
            self.checkpoint = ckpt.ckpt_path
            checkpoint_params = ckpt.checkpoint
            self.model_params = checkpoint_params.model
            # Without an explicit codec, decode with the charset stored in the checkpoint.
            self.codec = codec if codec else Codec(self.model_params.codec.charset)

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(self.network_params, restore=self.checkpoint, processes=processes)
            self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(self.model_params.data_preprocessor)
            self.network = backend.create_net(
                dataset=None,
                codec=self.codec,
                restore=self.checkpoint, weights=None, graph_type="predict", batch_size=batch_size)
        elif network:
            # Validate first: a preloaded network carries no codec information.
            if not codec:
                raise Exception("A codec is required if preloaded network is used.")
            self.codec = codec
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
        else:
            raise Exception("Either a checkpoint or a existing backend must be provided")

        self.out_to_in_trans = OutputToInputTransformer(self.data_preproc, self.network)
Exemple #7
0
def main():
    """Evaluate predicted text against ground truth and print the results.

    GT (and optionally prediction) files are resolved from the command line.
    Alternatively, a single json file can be passed via ``--gt``; its keys
    override the parsed arguments.  The files are evaluated with an
    ``Evaluator``, summarized on stdout and optionally written to an xlsx file.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--gt",
        nargs="+",
        required=True,
        help="Ground truth files (.gt.txt extension). "
        "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        type=str,
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as would it be empty (fully checking for errors). "
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--skip_empty_gt",
                        action="store_true",
                        default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    # page xml specific args (type=int so values given on the CLI are not strings)
    parser.add_argument("--pagexml_gt_text_index", type=int, default=0)
    parser.add_argument("--pagexml_pred_text_index", type=int, default=1)

    args = parser.parse_args()

    # check if loading a json file
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    handling_mode = args.non_existing_file_handling_mode.lower()

    if handling_mode == "skip":
        # Drop every (gt, pred) pair whose prediction file does not exist,
        # in one pass instead of repeated list.index()/del.
        missing = {i for i, p in enumerate(pred_files) if not os.path.exists(p)}
        pred_files = [p for i, p in enumerate(pred_files) if i not in missing]
        gt_files = [g for i, g in enumerate(gt_files) if i not in missing]

    text_preproc = None
    if args.checkpoint:
        with open(
                args.checkpoint if args.checkpoint.endswith(".json") else
                args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(
                checkpoint_params.model.text_preprocessor)

    # Missing files are read as empty unless the mode is "error".  The former
    # comparison against "error " (trailing space) was never equal, so the
    # "error" mode silently behaved like "empty".
    non_existing_as_empty = handling_mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc,
                          skip_empty_gt=args.skip_empty_gt)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
        }])
Exemple #8
0
    def __init__(
        self,
        checkpoint_params,
        dataset,
        validation_dataset=None,
        txt_preproc=None,
        txt_postproc=None,
        data_preproc=None,
        data_augmenter: DataAugmenter = None,
        n_augmentations=0,
        weights=None,
        codec=None,
        codec_whitelist=None,
        auto_update_checkpoints=True,
        preload_training=False,
        preload_validation=False,
    ):
        """Train a DNN using given preprocessing, weights, and data

        The purpose of the Trainer is to handle a default training mechanism.
        As required input it expects a `dataset` and hyperparameters (`checkpoint_params`).

        The steps are
            1. Loading and preprocessing of the dataset
            2. Computation of the codec
            3. Construction of the DNN in the desired Deep Learning Framework
            4. Launch of the training

        During the training the Trainer will perform validation checks if a `validation_dataset` is given
        to determine the best model.
        Furthermore, the current status is printed and checkpoints are written.

        Parameters
        ----------
        checkpoint_params : CheckpointParams
            Proto parameter object that defines all hyperparameters of the model
        dataset : Dataset
            The Dataset used for training
        validation_dataset : Dataset, optional
            The Dataset used for validation, i.e. choosing the best model
        txt_preproc : TextProcessor, optional
            Text preprocessor that is applied on loaded text, before the Codec is computed
        txt_postproc : TextProcessor, optional
            Text processor that is applied on the loaded GT text and on the prediction to receive the final result
        data_preproc : DataProcessor, optional
            Preprocessing for the image lines (e. g. padding, inversion, deskewing, ...)
        data_augmenter : DataAugmenter, optional
            A DataAugmenter object to use for data augmentation. Count is set by `n_augmentations`
        n_augmentations : int, optional
            The number of augmentations performed by the `data_augmenter`
        weights : str, optional
            Path to a trained model for loading its weights
        codec : Codec, optional
            If provided the Codec will not be computed automatically based on the GT, but instead `codec` will be used
        codec_whitelist : obj:`list` of :obj:`str`
            List of characters to be kept when the loaded `weights` have a different codec than the new one.
        auto_update_checkpoints : bool, optional
            Stored on the instance as `auto_update_checkpoints`
        preload_training : bool, optional
            Stored on the instance as `preload_training`
        preload_validation : bool, optional
            Stored on the instance as `preload_validation`

        Raises
        ------
        Exception
            If the training dataset is empty, or if a validation dataset was
            wrapped but contains no samples.
        """
        self.checkpoint_params = checkpoint_params
        self.txt_preproc = txt_preproc if txt_preproc else text_processor_from_proto(
            checkpoint_params.model.text_preprocessor, "pre")
        self.txt_postproc = txt_postproc if txt_postproc else text_processor_from_proto(
            checkpoint_params.model.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
            checkpoint_params.model.data_preprocessor)
        self.weights = checkpoint_path(weights) if weights else None
        self.codec = codec
        self.codec_whitelist = [] if codec_whitelist is None else codec_whitelist
        self.auto_update_checkpoints = auto_update_checkpoints
        self.dataset = InputDataset(dataset, self.data_preproc,
                                    self.txt_preproc, data_augmenter,
                                    n_augmentations)
        self.validation_dataset = InputDataset(
            validation_dataset, self.data_preproc,
            self.txt_preproc) if validation_dataset else None
        self.preload_training = preload_training
        self.preload_validation = preload_validation

        if len(self.dataset) == 0:
            raise Exception("Dataset is empty.")

        # Compare against None explicitly: an object defining __len__ (and no
        # __bool__) is falsy exactly when its length is 0, so a plain truthiness
        # test here would skip this check in the only case it is meant for.
        if self.validation_dataset is not None and len(self.validation_dataset) == 0:
            raise Exception(
                "Validation dataset is empty. Provide valid validation data for early stopping."
            )
def main():
    """Visualize preprocessed (and optionally augmented) lines of a dataset.

    Builds the data and text preprocessing pipelines from the command line and
    loads the selected samples (all by default).  The lines are plotted in
    pages of ``n_rows`` x ``n_cols`` images, each titled with its sample index
    and text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location",
        required=True)
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--processes",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")

    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument("--data_preprocessing",
                        nargs="+",
                        type=DataPreprocessorParams.Type.Value,
                        choices=DataPreprocessorParams.Type.values(),
                        default=[DataPreprocessorParams.DEFAULT_NORMALIZER])

    args = parser.parse_args()

    # Text/Data processing
    if args.data_preprocessing is None or len(args.data_preprocessing) == 0:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    data_preprocessor = DataPreprocessorParams()
    data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        pp = data_preprocessor.children.add()
        pp.type = preproc
        pp.line_height = args.line_height
        pp.pad = args.pad

    # Text pre processing (reading)
    text_preprocessor = TextProcessorParams()
    text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(text_preprocessor.children.add(),
                                   default=args.text_normalization)
    default_text_regularizer_params(text_preprocessor.children.add(),
                                    groups=args.text_regularization)
    strip_processor_params = text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    text_preprocessor = text_processor_from_proto(text_preprocessor)
    data_preprocessor = data_processor_from_proto(data_preprocessor)

    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        if args.gt_extension:
            gt_txt_files = [
                split_all_ext(f)[0] + args.gt_extension
                for f in input_image_files
            ]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    # Duplicates: check the images, and the gt files only where they exist.
    # Without --text_files/--gt_extension every gt entry is None, which must not
    # be reported as a duplicate.
    known_gt_files = [gt for gt in gt_txt_files if gt is not None]
    if (len(set(input_image_files)) != len(input_image_files)
            or len(set(known_gt_files)) != len(known_gt_files)):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        non_existing_as_empty=True,
    )

    all_samples = dataset.samples()
    if len(args.select) == 0:
        args.select = range(len(all_samples))
        dataset._samples = all_samples
    else:
        dataset._samples = [all_samples[i] for i in args.select]

    samples = dataset.samples()

    print("Found {} files in the dataset".format(len(dataset)))

    def new_page():
        # squeeze=False keeps ``ax`` a 2-D array for every grid size, so
        # ``ax[row, col]`` also works for 1x1 and 1xN layouts.
        return plt.subplots(args.n_rows, args.n_cols, sharey='all',
                            squeeze=False)

    with StreamingInputDataset(
            dataset,
            data_preprocessor,
            text_preprocessor,
            SimpleDataAugmenter(),
            args.n_augmentations,
    ) as input_dataset:
        _, ax = new_page()
        row, col = 0, 0
        for i, (sample_id, sample) in enumerate(
                zip(args.select, input_dataset.generator(args.processes))):
            line, text, _ = sample
            ax[row, col].imshow(line.transpose())
            ax[row, col].set_title("ID: {}\n{}".format(sample_id, text))

            # Fill the page column by column.
            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(samples) - 1:
                plt.show()
                _, ax = new_page()
                row, col = 0, 0
Exemple #10
0
def main():
    """Evaluate one or more checkpoints, individually and combined by voting.

    Every checkpoint given via ``--checkpoint`` predicts the images matched by
    ``--eval_imgs``. The ground truth of each image is expected in the file with
    the same base name and a ``.gt.txt`` extension. The predictions of every
    single model and the output of every requested voter are then evaluated
    against the ground truth. ``--verbose`` prints the full results and
    ``--dump`` writes them (plus the gt) to a pickle file.

    Raises
    ------
    Exception
        If the resolved dataset is empty, or if the number of predicted
        sentences does not match the number of ground truth samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+", default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm(s) to use. Possible values: sequence_voter, "
                             "confidence_voter_default_ctc, confidence_voter_fuzzy_ctc. "
                             "By default all three voters are evaluated.")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # Allow the user to pass the .json model definition, but strip the extension
    # because the predictor expects the bare checkpoint path.
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # Resolve the images once and derive the gt file names from that same sorted
    # list so images and texts are guaranteed to stay aligned.
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # Report the argument that actually exists; ``args.files`` is not defined
        # and used to raise an AttributeError instead of this message.
        raise Exception("Empty dataset provided. Check your eval_imgs argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # Voter names are upper-cased to look them up in the VoterParams.Type enum.
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, _sample in do_prediction:
        # Collect the sentence of every individual model ...
        for sentences, p in zip(all_prediction_sentences, prediction):
            sentences.append(p.sentence)

        # ... and the voted result of every voter.
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        """Evaluate one list of predicted sentences against the preloaded gt."""
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        return evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)

    # Individual models are keyed by their index (as string), voters by their name.
    labelled_sentences = [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] \
        + list(zip(args.voter, all_voter_sentences))
    full_evaluation = {}
    for key, data in labelled_sentences:
        full_evaluation[key] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
Exemple #11
0
def main():
    """Evaluate predicted text files against their ground truth and report error rates.

    Ground truth files are given via ``--gt``. Prediction files are either given via
    ``--pred`` (which must resolve to the same number of files), or derived from each
    gt path by replacing its extensions with ``--pred_ext``. The function prints the
    mean normalized label error rate, the most common confusions and, optionally, the
    worst lines. It can also write the results to an xlsx file.

    Raises
    ------
    Exception
        If ``--pred`` resolves to a different number of files than ``--gt``.
    """
    parser = ArgumentParser()
    parser.add_argument("--gt",
                        nargs="+",
                        required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        type=str,
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as would it be empty (fully checking for errors). "
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
        if len(pred_files) != len(gt_files):
            raise Exception(
                "Mismatch in the number of gt and pred files: {} vs {}".format(
                    len(gt_files), len(pred_files)))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]

    mode = args.non_existing_file_handling_mode.lower()
    if mode == "skip":
        # Drop every gt/pred pair whose prediction file is missing. The lists are
        # index-aligned (equal length is checked or guaranteed above), so filtering
        # both by position in one pass replaces the quadratic index()/del loop.
        missing = {i for i, p in enumerate(pred_files) if not os.path.exists(p)}
        pred_files = [p for i, p in enumerate(pred_files) if i not in missing]
        gt_files = [g for i, g in enumerate(gt_files) if i not in missing]

    text_preproc = None
    if args.checkpoint:
        checkpoint_json = args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json'
        # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
        with open(checkpoint_json, 'r', encoding='utf-8') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
        text_preproc = text_processor_from_proto(
            checkpoint_params.model.text_preprocessor)

    non_existing_as_empty = mode == "empty"
    gt_data_set = FileDataSet(texts=gt_files,
                              non_existing_as_empty=non_existing_as_empty)
    pred_data_set = FileDataSet(texts=pred_files,
                                non_existing_as_empty=non_existing_as_empty)

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # Most common confusions first, limited to --n_confusions.
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_files, gt_data_set.text_samples(),
                      pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
            "gts": gt_data_set.text_samples(),
            "preds": pred_data_set.text_samples()
        }])
Exemple #12
0
    def __init__(self, checkpoint=None, text_postproc=None, data_preproc=None, codec=None, network=None,
                 batch_size=1, processes=1,
                 auto_update_checkpoints=True,
                 with_gt=False,
                 ):
        """ Predicting a dataset based on a trained model

        Exactly one of `checkpoint` or `network` must be given.

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load, alternatively you can directly use a loaded `network`
        text_postproc : TextProcessor, optional
            text processor to be applied on the predicted sentence for the final output.
            If omitted and a checkpoint is loaded, the text postprocessor stored in the checkpoint is used.
        data_preproc : DataProcessor, optional
            data processor (must be the same as of the trained model) to be applied to the input image.
            If omitted and a checkpoint is loaded, the data preprocessor stored in the checkpoint is used.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to used. Alternatively you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction. Only used when the network is created from a `checkpoint`.
        processes : int, optional
            The number of processes to use for prediction
        auto_update_checkpoints : bool, optional
            Update old models automatically (this will change the checkpoint files)
        with_gt : bool, optional
            The prediction will also output the ground truth if available else None

        Raises
        ------
        Exception
            If both or neither of `checkpoint` and `network` are given, or if a `network` is given without a `codec`.
        """
        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes
        self.auto_update_checkpoints = auto_update_checkpoints
        self.with_gt = with_gt

        if checkpoint:
            if network:
                raise Exception("Either a checkpoint or a network can be provided")

            ckpt = Checkpoint(checkpoint, auto_update=self.auto_update_checkpoints)
            checkpoint_params = ckpt.checkpoint
            self.model_params = checkpoint_params.model
            # An explicitly passed codec overrides the charset stored in the checkpoint.
            self.codec = codec if codec else Codec(self.model_params.codec.charset)

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(self.network_params, restore=self.checkpoint, processes=processes)
            self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(self.model_params.data_preprocessor)
            self.network = backend.create_net(
                dataset=None,
                codec=self.codec,
                restore=self.checkpoint, weights=None, graph_type="predict", batch_size=batch_size)
        elif network:
            # Validate first: a preloaded network carries no charset to build a codec from.
            if not codec:
                raise Exception("A codec is required if preloaded network is used.")
            self.codec = codec
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
        else:
            raise Exception("Either a checkpoint or a existing backend must be provided")

        self.out_to_in_trans = OutputToInputTransformer(self.data_preproc, self.network)
Exemple #13
0
def main():
    """Evaluate predicted text lines against ground truth for any supported dataset type.

    Ground truth is read from ``--gt`` using the ``--dataset`` type. Predictions are
    read from ``--pred`` using ``--pred_dataset``. If ``--pred`` is omitted, prediction
    paths are derived from the gt paths by replacing their extensions with
    ``--pred_ext``, and they share the gt dataset type. The function prints the mean
    normalized label error rate, the most common confusions and, optionally, the worst
    lines. It can also write the results to an xlsx file.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with its error")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as would it be empty (fully checking for errors). "
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args; type=int so values given on the command line match the int defaults
    parser.add_argument("--pagexml_gt_text_index", type=int, default=0,
                        help="Text index passed to the gt dataset (PageXML only)")
    parser.add_argument("--pagexml_pred_text_index", type=int, default=1,
                        help="Text index passed to the prediction dataset (PageXML only)")

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        # Derived prediction files live next to the gt, so they share its dataset type.
        args.pred_dataset = args.dataset

    mode = args.non_existing_file_handling_mode.lower()
    if mode == "skip":
        # Drop every gt/pred pair (paired by position) whose prediction file is missing,
        # filtering both lists in one pass instead of a quadratic index()/del loop.
        missing = {i for i, p in enumerate(pred_files) if not os.path.exists(p)}
        pred_files = [p for i, p in enumerate(pred_files) if i not in missing]
        gt_files = [g for i, g in enumerate(gt_files) if i not in missing]

    text_preproc = None
    if args.checkpoint:
        checkpoint_json = args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json'
        # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
        with open(checkpoint_json, 'r', encoding='utf-8') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
        text_preproc = text_processor_from_proto(checkpoint_params.model.text_preprocessor)

    # Missing files are read as empty text in every mode except "error". This used to compare
    # against "error " (trailing space), which never matched, so "error" mode never failed.
    non_existing_as_empty = mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # Most common confusions first, limited to --n_confusions.
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                       "gts": gt_data_set.text_samples(),
                       "preds": pred_data_set.text_samples()
                   }])
Exemple #14
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 network=None,
                 batch_size=1,
                 processes=1):
        """ Predicting a dataset based on a trained model

        Exactly one of `checkpoint` or `network` must be given.

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load, alternatively you can directly use a loaded `network`.
            The model definition is read from ``checkpoint + '.json'``.
        text_postproc : TextProcessor, optional
            text processor to be applied on the predicted sentence for the final output.
            If omitted and a checkpoint is loaded, the text postprocessor stored in the checkpoint is used.
        data_preproc : DataProcessor, optional
            data processor (must be the same as of the trained model) to be applied to the input image.
            If omitted and a checkpoint is loaded, the data preprocessor stored in the checkpoint is used.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to used. Alternatively you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction. Only used when the network is created from a `checkpoint`.
        processes : int, optional
            The number of processes to use for prediction

        Raises
        ------
        Exception
            If both or neither of `checkpoint` and `network` are given, or if a `network` is given without a `codec`.
        """
        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes

        if checkpoint:
            if network:
                raise Exception(
                    "Either a checkpoint or a network can be provided")

            # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
            with open(checkpoint + '.json', 'r', encoding='utf-8') as f:
                checkpoint_params = json_format.Parse(f.read(),
                                                      CheckpointParams())
                self.model_params = checkpoint_params.model

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(self.network_params,
                                                restore=self.checkpoint,
                                                processes=processes)
            self.network = backend.create_net(restore=self.checkpoint,
                                              weights=None,
                                              graph_type="predict",
                                              batch_size=batch_size)
            self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(
                self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
                self.model_params.data_preprocessor)
        elif network:
            # Validate first: a preloaded network carries no charset to build a codec from.
            if not codec:
                raise Exception(
                    "A codec is required if preloaded network is used.")
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
        else:
            raise Exception(
                "Either a checkpoint or a existing backend must be provided")

        # An explicit codec wins. Otherwise the charset stored in the checkpoint is used,
        # which is only reachable from the checkpoint branch because a network requires a codec.
        self.codec = codec if codec else Codec(self.model_params.codec.charset)
        self.out_to_in_trans = OutputToInputTransformer(
            self.data_preproc, self.network)
Exemple #15
0
    def __init__(self, checkpoint_params,
                 dataset,
                 validation_dataset=None,
                 txt_preproc=None,
                 txt_postproc=None,
                 data_preproc=None,
                 data_augmenter: DataAugmenter = None,
                 n_augmentations=0,
                 weights=None,
                 codec=None,
                 codec_whitelist=None,
                 auto_update_checkpoints=True,
                 preload_training=False,
                 preload_validation=False,
                 ):
        """Train a DNN using given preprocessing, weights, and data

        The purpose of the Trainer is to handle a default training mechanism.
        As required input it expects a `dataset` and hyperparameters (`checkpoint_params`).

        The steps are
            1. Loading and preprocessing of the dataset
            2. Computation of the codec
            3. Construction of the DNN in the desired Deep Learning Framework
            4. Launch of the training

        During the training the Trainer will perform validation checks if a `validation_dataset` is given
        to determine the best model.
        Furthermore, the current status is printed and checkpoints are written.

        Parameters
        ----------
        checkpoint_params : CheckpointParams
            Proto parameter object that defines all hyperparameters of the model
        dataset : Dataset
            The Dataset used for training
        validation_dataset : Dataset, optional
            The Dataset used for validation, i.e. choosing the best model
        txt_preproc : TextProcessor, optional
            Text preprocessor that is applied on loaded text, before the Codec is computed.
            Defaults to the text preprocessor defined in `checkpoint_params`.
        txt_postproc : TextProcessor, optional
            Text processor that is applied on the loaded GT text and on the prediction to receive the final result.
            Defaults to the text postprocessor defined in `checkpoint_params`.
        data_preproc : DataProcessor, optional
            Preprocessing for the image lines (e. g. padding, inversion, deskewing, ...).
            Defaults to the data preprocessor defined in `checkpoint_params`.
        data_augmenter : DataAugmenter, optional
            A DataAugmenter object to use for data augmentation. Count is set by `n_augmentations`
        n_augmentations : int, optional
            The number of augmentations performed by the `data_augmenter`
        weights : str, optional
            Path to a trained model for loading its weights
        codec : Codec, optional
            If provided the Codec will not be computed automatically based on the GT, but instead `codec` will be used
        codec_whitelist : obj:`list` of :obj:`str`, optional
            List of characters to be kept when the loaded `weights` have a different codec than the new one.
            Defaults to an empty list.
        auto_update_checkpoints : bool, optional
            Stored as `self.auto_update_checkpoints`; presumably allows old checkpoints to be updated
            automatically (TODO confirm against its use outside this constructor).
        preload_training : bool, optional
            Stored as `self.preload_training`; presumably preloads the training data (TODO confirm).
        preload_validation : bool, optional
            Stored as `self.preload_validation`; presumably preloads the validation data (TODO confirm).

        Raises
        ------
        Exception
            If the training dataset, or a given validation dataset, is empty.
        """
        # Use a fresh list per call; a mutable ``[]`` default would be shared by all Trainer instances.
        # The local is rebound (not only the attribute) so any later use of ``codec_whitelist`` sees a list.
        if codec_whitelist is None:
            codec_whitelist = []

        self.checkpoint_params = checkpoint_params
        self.txt_preproc = txt_preproc if txt_preproc else text_processor_from_proto(checkpoint_params.model.text_preprocessor, "pre")
        self.txt_postproc = txt_postproc if txt_postproc else text_processor_from_proto(checkpoint_params.model.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(checkpoint_params.model.data_preprocessor)
        self.weights = checkpoint_path(weights) if weights else None
        self.codec = codec
        self.codec_whitelist = codec_whitelist
        self.auto_update_checkpoints = auto_update_checkpoints
        self.dataset = InputDataset(dataset, self.data_preproc, self.txt_preproc, data_augmenter, n_augmentations)
        # Test against None rather than truthiness: datasets define __len__, so an empty one is falsy
        # and would otherwise silently disable validation instead of triggering the check below.
        self.validation_dataset = InputDataset(validation_dataset, self.data_preproc, self.txt_preproc) if validation_dataset is not None else None
        self.preload_training = preload_training
        self.preload_validation = preload_validation

        if len(self.dataset) == 0:
            raise Exception("Dataset is empty.")

        if self.validation_dataset is not None and len(self.validation_dataset) == 0:
            raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")