コード例 #1
0
def main():
    """Evaluate predicted text lines against ground truth and print the results.

    Command-line entry point. Ground-truth inputs come from ``--gt``; if that
    is a single path ending in ``json``, the file is loaded and each of its
    keys overrides the parsed argument of the same name. Predictions come
    from ``--pred`` or, if omitted, from the gt paths with their extension
    replaced by ``--pred_ext`` (in that case they are read with the gt
    dataset type).

    Prints the mean normalized label error rate, the most common confusions
    and optionally the worst lines; optionally writes an xlsx report.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--gt",
        nargs="+",
        required=True,
        help="Ground truth files (.gt.txt extension). "
        "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        # Lower-case before validating so e.g. "Skip" is still accepted,
        # while typos are rejected instead of silently falling through.
        type=str.lower,
        choices=["skip", "empty", "error"],
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as would it be empty (fully checking for errors). "
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--skip_empty_gt",
                        action="store_true",
                        default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    # page xml specific args
    # type=int so values given on the command line have the same type as
    # the integer defaults.
    parser.add_argument("--pagexml_gt_text_index", type=int, default=0)
    parser.add_argument("--pagexml_pred_text_index", type=int, default=1)

    args = parser.parse_args()

    # check if loading a json file; its keys override the parsed arguments
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    # .lower() is kept because a json config may bypass argparse's type.
    mode = args.non_existing_file_handling_mode.lower()

    if mode == "skip":
        # Drop every (gt, pred) pair whose prediction file is missing.
        # Filtering by position keeps both lists aligned and runs in a
        # single pass instead of repeated list.index()/del.
        missing = {i for i, p in enumerate(pred_files) if not os.path.exists(p)}
        pred_files = [p for i, p in enumerate(pred_files) if i not in missing]
        gt_files = [g for i, g in enumerate(gt_files) if i not in missing]

    text_preproc = None
    if args.checkpoint:
        with open(
                args.checkpoint if args.checkpoint.endswith(".json") else
                args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(
                checkpoint_params.model.text_preprocessor)

    # Only "error" mode must fail on missing files; "empty" treats them as
    # empty, and in "skip" mode missing predictions were removed above.
    non_existing_as_empty = mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc,
                          skip_empty_gt=args.skip_empty_gt)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
        }])
コード例 #2
0
ファイル: nashi_client.py プロジェクト: stweil/nashi
    def evaluate_books(self, books, models, rtl=False, mode="auto", sample=-1):
        """Score one or more models on the lines of one or more books.

        Args:
            books: A book identifier (top-level group in ``self.cachefile``)
                or a list of them.
            models: A checkpoint, a list of checkpoints, or a list whose items
                are themselves lists of checkpoints; each item is loaded into
                one ``MultiPredictor``.
            rtl: If true, predicted texts are post-processed with
                ``self.bidi_preproc`` instead of ``self.txt_preproc``
                (evaluation mode only).
            mode: ``"conf"`` reports the mean ``avg_char_probability`` of the
                predictions; any other value (e.g. ``"eval"``) evaluates the
                predictions against the dataset with ``Evaluator``.
                ``"auto"`` picks ``"eval"`` if any cached entry of the given
                books has a ``"text"`` attribute, otherwise ``"conf"``.
            sample: If ``0 < sample < len(dataset)``, only a random subset of
                ``sample`` lines is used.

        Returns:
            A dict mapping ``"/".join(checkpoints)`` of each model to its
            score: ``1 - avg_ler`` in evaluation mode, or the average
            confidence in ``"conf"`` mode (0.0 if nothing was predicted).
        """
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        results = {}

        if mode == "auto":
            # A "text" attribute (presumably ground truth) on any entry is
            # enough to evaluate; any() stops at the first hit.
            with h5py.File(self.cachefile, 'r', libver='latest',
                           swmr=True) as cache:
                has_text = any("text" in cache[b][p][s].attrs
                               for b in books
                               for p in cache[b]
                               for s in cache[b][p])
            mode = "eval" if has_text else "conf"

        if mode == "conf":
            dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
        else:
            dset = Nash5DataSet(DataSetMode.TRAIN, self.cachefile, books)
            dset.mode = DataSetMode.PREDICT  # otherwise results are randomised

        if 0 < sample < len(dset):
            # Drop all but `sample` randomly chosen lines, keeping survivors
            # in their original order. random.sample picks positions based
            # only on the population size, so sampling indices selects the
            # same lines as sampling the list itself, in one O(n) pass.
            dropped = set(random.sample(range(len(dset._samples)),
                                        len(dset) - sample))
            dset._samples[:] = [s for i, s in enumerate(dset._samples)
                                if i not in dropped]

        def make_predictor_and_voter(checkpoints):
            # Both modes predict one line at a time and vote over the
            # prediction results with the default CTC confidence voter.
            predictor = MultiPredictor(checkpoints=checkpoints,
                                       data_preproc=NoopDataPreprocessor(),
                                       batch_size=1,
                                       processes=1)
            voter_params = VoterParams()
            voter_params.type = VoterParams.Type.Value(
                "confidence_voter_default_ctc".upper())
            return predictor, voter_from_proto(voter_params)

        for model in models:
            checkpoints = [model] if isinstance(model, str) else model
            predictor, voter = make_predictor_and_voter(checkpoints)

            if mode == "conf":
                confidence_sum = 0
                n_predictions = 0
                # The second tuple element is unused; `_` also avoids
                # shadowing the `sample` parameter.
                for result, _ in predictor.predict_dataset(dset,
                                                           progress_bar=True):
                    n_predictions += 1
                    prediction = voter.vote_prediction_result(result)
                    confidence_sum += prediction.avg_char_probability
                # An empty dataset yields no predictions; report 0.0 rather
                # than raising ZeroDivisionError.
                results["/".join(checkpoints)] = (
                    confidence_sum / n_predictions if n_predictions else 0.0)
            else:
                out_gen = predictor.predict_dataset(dset, progress_bar=True)
                preproc = self.bidi_preproc if rtl else self.txt_preproc
                pred_dset = RawDataSet(DataSetMode.EVAL,
                                       texts=preproc.apply([
                                           voter.vote_prediction_result(
                                               d[0]).sentence for d in out_gen
                                       ]))
                evaluator = Evaluator(text_preprocessor=NoopTextProcessor(),
                                      skip_empty_gt=False)
                r = evaluator.run(gt_dataset=dset,
                                  pred_dataset=pred_dset,
                                  processes=1,
                                  progress_bar=True)
                results["/".join(checkpoints)] = 1 - r["avg_ler"]
        return results
コード例 #3
0
def main():
    """Evaluate plain-text prediction files against ground-truth files.

    Command-line entry point. Ground-truth files are resolved from the
    ``--gt`` globs; prediction files from ``--pred`` (which must resolve to
    the same number of files) or, if omitted, from the gt paths with their
    extension replaced by ``--pred_ext``. Prints the mean normalized label
    error rate, the most common confusions and optionally the worst lines;
    optionally writes an xlsx report.

    Raises:
        Exception: If ``--pred`` resolves to a different number of files
            than ``--gt``.
    """
    parser = ArgumentParser()
    parser.add_argument("--gt",
                        nargs="+",
                        required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        # Lower-case before validating so e.g. "Skip" is still accepted,
        # while typos are rejected instead of silently acting like "error".
        type=str.lower,
        choices=["skip", "empty", "error"],
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as would it be empty (fully checking for errors). "
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
        if len(pred_files) != len(gt_files):
            raise Exception(
                "Mismatch in the number of gt and pred files: {} vs {}".format(
                    len(gt_files), len(pred_files)))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]

    mode = args.non_existing_file_handling_mode.lower()

    if mode == "skip":
        # Drop every (gt, pred) pair whose prediction file is missing.
        # Filtering by position keeps both lists aligned and runs in a
        # single pass instead of repeated list.index()/del.
        missing = {i for i, p in enumerate(pred_files) if not os.path.exists(p)}
        pred_files = [p for i, p in enumerate(pred_files) if i not in missing]
        gt_files = [g for i, g in enumerate(gt_files) if i not in missing]

    text_preproc = None
    if args.checkpoint:
        with open(
                args.checkpoint if args.checkpoint.endswith(".json") else
                args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(
                checkpoint_params.model.text_preprocessor)

    non_existing_as_empty = mode == "empty"
    gt_data_set = FileDataSet(texts=gt_files,
                              non_existing_as_empty=non_existing_as_empty)
    pred_data_set = FileDataSet(texts=pred_files,
                                non_existing_as_empty=non_existing_as_empty)

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_files, gt_data_set.text_samples(),
                      pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
            "gts": gt_data_set.text_samples(),
            "preds": pred_data_set.text_samples()
        }])
コード例 #4
0
ファイル: eval.py プロジェクト: CurtLH/calamari
def main():
    """Compare predicted text files with ground-truth files and report errors.

    Resolves the ``--gt`` globs and either the ``--pred`` globs (which must
    yield as many files) or prediction paths derived from the gt paths via
    ``--pred_ext``. Runs the ``Evaluator`` over both file sets, then prints
    the mean normalized label error rate and, if there were sync errors, a
    table of the most common confusions (all of them when ``--n_confusions``
    is negative, none when it is 0).
    """
    parser = ArgumentParser()
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=-1,
                        help="Only print n most common confusions")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    args = parser.parse_args()

    gt_files = sorted(glob_all(args.gt))
    if not args.pred:
        pred_files = [split_all_ext(path)[0] + args.pred_ext for path in gt_files]
    else:
        pred_files = sorted(glob_all(args.pred))
        if len(pred_files) != len(gt_files):
            raise Exception("Mismatch in the number of gt and pred files: {} vs {}".format(
                len(gt_files), len(pred_files)))

    gt_dataset = FileDataSet(texts=gt_files)
    pred_dataset = FileDataSet(texts=pred_files)
    result = Evaluator().run(gt_dataset=gt_dataset, pred_dataset=pred_dataset,
                             processes=args.num_threads, progress_bar=True)

    # TODO: More output
    for heading in ("Evaluation result", "=================", ""):
        print(heading)
    sync_errs = result["total_sync_errs"]
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        result["avg_ler"], result["total_char_errs"], result["total_chars"], sync_errs))

    # Nothing to tabulate when confusions are disabled or no sync errors exist.
    if args.n_confusions == 0 or not sync_errs > 0:
        return

    # Most frequent first; reverse=True keeps ties in their original order.
    ranked = sorted(result['confusion'].items(), key=lambda kv: kv[1], reverse=True)
    shown = ranked if args.n_confusions < 0 else ranked[:args.n_confusions]

    print("{:8s} {:8s} {:8s} {:10s}".format("GT", "PRED", "COUNT", "PERCENT"))
    shown_share = 0
    for (gt, pred), count in shown:
        share = count * max(len(gt), len(pred)) / sync_errs
        print("{:8s} {:8s} {:8d} {:10.2%}".format("{" + gt + "}", "{" + pred + "}", count, share))
        shown_share += share

    print("The remaining but hidden errors make up {:.2%}".format(1.0 - shown_share))
コード例 #5
0
ファイル: eval.py プロジェクト: AIRob/calamari
def main():
    """Evaluate predicted text lines against ground truth and print the results.

    Command-line entry point. Ground-truth inputs come from ``--gt``;
    predictions from ``--pred`` or, if omitted, from the gt paths with their
    extension replaced by ``--pred_ext`` (then read with the gt dataset
    type). Prints the mean normalized label error rate, the most common
    confusions and optionally the worst lines; optionally writes an xlsx
    report.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with its error")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    # Lower-case before validating so e.g. "Skip" is still accepted, while
    # typos are rejected instead of silently falling through.
    parser.add_argument("--non_existing_file_handling_mode", type=str.lower,
                        choices=["skip", "empty", "error"], default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as would it be empty (fully checking for errors). "
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args
    # type=int so command-line values match the integer defaults.
    parser.add_argument("--pagexml_gt_text_index", type=int, default=0)
    parser.add_argument("--pagexml_pred_text_index", type=int, default=1)

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    mode = args.non_existing_file_handling_mode.lower()

    if mode == "skip":
        # Drop every (gt, pred) pair whose prediction file is missing, in one
        # pass and keeping both lists aligned by position.
        missing = {i for i, p in enumerate(pred_files) if not os.path.exists(p)}
        pred_files = [p for i, p in enumerate(pred_files) if i not in missing]
        gt_files = [g for i, g in enumerate(gt_files) if i not in missing]

    text_preproc = None
    if args.checkpoint:
        with open(args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(checkpoint_params.model.text_preprocessor)

    # Only "error" mode must fail on missing files; "empty" treats them as
    # empty, and in "skip" mode missing predictions were removed above.
    non_existing_as_empty = mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                       "gts": gt_data_set.text_samples(),
                       "preds": pred_data_set.text_samples()
                   }])