예제 #1
0
def main():
    """Print summary statistics of a line-image dataset.

    Every image resolved from ``--files`` is paired with a ground-truth file that has
    the same base name and the ``.gt.txt`` extension. Missing ground-truth files are
    loaded as empty text (``non_existing_as_empty=True``), and empty samples are
    skipped when reading the training samples.

    The printed dict contains:
    - the number of lines;
    - line-width statistics (each image scaled to ``--line_height`` plus ``2 * --pad``);
    - per-line character-count statistics;
    - per-character counts and the codec size (number of distinct characters).

    Raises:
        ValueError: if no usable line (valid image and text) remains after loading.
    """
    import collections  # local import keeps this script entry point self-contained

    parser = argparse.ArgumentParser()
    parser.add_argument("--files", nargs="+", required=True,
                        help="List of all image files with corresponding gt.txt files")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--line_height", type=int, default=48,
                        help="The line height")
    parser.add_argument("--pad", type=int, default=16,
                        help="Padding (left right) of the line")

    args = parser.parse_args()

    print("Resolving files")
    image_files = glob_all(args.files)
    gt_files = [split_all_ext(p)[0] + ".gt.txt" for p in image_files]

    ds = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=image_files, texts=gt_files, non_existing_as_empty=True)

    print("Loading {} files".format(len(image_files)))
    ds.load_samples(processes=1, progress_bar=True)
    images, texts = ds.train_samples(skip_empty=True)

    # Width of each line once scaled to ``line_height`` (aspect ratio preserved) plus
    # left and right padding; missing or zero-sized images are ignored.
    widths = [img.shape[1] / img.shape[0] * args.line_height + 2 * args.pad for img in images
              if img is not None and img.shape[0] > 0 and img.shape[1] > 0]
    chars = [len(text) for text in texts]
    if not widths or not chars:
        # np.max / np.min fail with an opaque error on empty input; report the real cause.
        raise ValueError("No usable lines found. Check your files argument (got {})!".format(args.files))

    # Counter keeps first-seen insertion order; convert to a plain dict so the printed
    # output shows a dict rather than Counter(...).
    char_counts = dict(collections.Counter(c for text in texts for c in text))

    av_line_width = np.average(widths)
    av_chars = np.average(chars)
    # Keys are listed in a fixed order so the printed report stays stable.
    statistics = {
        "n_lines": len(images),
        "total_line_width": np.sum(widths),
        "char_counts": char_counts,
        "av_line_width": av_line_width,
        "max_line_width": np.max(widths),
        "min_line_width": np.min(widths),
        "av_chars": av_chars,
        "max_chars": np.max(chars),
        "min_chars": np.min(chars),
        "total_chars": np.sum(chars),
        "av_px_per_char": av_line_width / av_chars,
        "codec_size": len(char_counts),
    }

    print(statistics)
예제 #2
0
 def test_raw_dataset_prediction(self):
     """Smoke test: predict over a FILE dataset built from the default prediction attrs.

     Passes as long as building the predictor/dataset and exhausting the prediction
     generator raise nothing; no result is inspected.
     """
     attrs = PredictionAttrs()
     first_checkpoint = attrs.checkpoint[0]
     predictor = Predictor(checkpoint=first_checkpoint)
     dataset = create_dataset(DataSetType.FILE, DataSetMode.PREDICT, images=attrs.files)
     # Drain the generator; each item is still unpacked into (prediction, sample).
     for _prediction, _sample in predictor.predict_dataset(dataset):
         continue
예제 #3
0
    def single_evaluation(predicted_sentences):
        """Evaluate one list of predicted sentences via the enclosing scope's ``evaluator``.

        The predictions are wrapped in a RAW/EVAL dataset and passed to ``evaluator.run``,
        whose result is returned. ``dataset``, ``evaluator`` and ``args`` come from the
        enclosing scope (not visible here).

        Raises:
            Exception: if the number of predictions differs from ``len(dataset)``.
        """
        n_gt, n_pred = len(dataset), len(predicted_sentences)
        if n_pred == n_gt:
            prediction_dataset = create_dataset(DataSetType.RAW, DataSetMode.EVAL, texts=predicted_sentences)
            return evaluator.run(pred_dataset=prediction_dataset, progress_bar=True, processes=args.processes)

        raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                        "not succeed".format(n_gt, n_pred))
예제 #4
0
    def single_evaluation(predicted_sentences):
        """Run ``evaluator`` on one list of predicted sentences and return its result.

        Relies on ``dataset``, ``evaluator`` and ``args`` from the enclosing scope
        (not visible here).

        Raises:
            Exception: if the prediction count does not match ``len(dataset)``.
        """
        expected = len(dataset)
        got = len(predicted_sentences)
        if got != expected:
            message = ("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                       "not succeed").format(expected, got)
            raise Exception(message)

        result = evaluator.run(
            pred_dataset=create_dataset(DataSetType.RAW, DataSetMode.EVAL, texts=predicted_sentences),
            progress_bar=True,
            processes=args.processes,
        )
        return result
예제 #5
0
def main():
    """Evaluate one or more checkpoints and several voting strategies on a dataset.

    Ground truth for every image in ``--eval_imgs`` is expected next to it, with the same
    base name and the ``.gt.txt`` extension. Predictions are evaluated separately for:
    - each checkpoint, keyed by its index as a string;
    - each voter's combined prediction, keyed by the voter name.
    Results are printed with ``--verbose`` and/or pickled to ``--dump``.

    Raises:
        Exception: if the resolved dataset is empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+", default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm(s) to use. Possible values: confidence_voter_default_ctc, "
                             "confidence_voter_fuzzy_ctc, sequence_voter (default: all three)")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]
    if not args.checkpoint:
        # The evaluation text preprocessor is read from the first model below, so fail
        # early instead of crashing after prediction.
        parser.error("at least one --checkpoint is required")

    # load files: glob once and derive the ground-truth paths from the sorted images
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # The parser defines --eval_imgs; there is no ``args.files`` attribute.
        raise Exception("Empty dataset provided. Check your eval_imgs argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, _sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        # Evaluate one list of predicted sentences against the preloaded ground truth.
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        r = evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)

        return r

    full_evaluation = {}
    for eval_id, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[eval_id] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
예제 #6
0
def main():
    """Resume training from a checkpoint on a training (and optional validation) dataset.

    Ground-truth files are resolved in one of two ways:
    - derived from the image paths: same base name plus ``--gt_extension`` /
      ``--validation_extension``, defaulting to the extension of the dataset type;
    - given explicitly via ``--text_files`` / ``--validation_text_files``. The images
      and texts are then paired with ``keep_files_with_same_file_name`` and checked
      for identical base names.

    Raises:
        Exception: if a paired image and ground-truth file differ in base name, or if a
            ground-truth file occurs more than once.
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation",
                        type=str,
                        nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument(
        "--validation_text_files",
        nargs="+",
        default=None,
        help=
        "Optional list of validation GT files if they are in other directory")
    parser.add_argument(
        "--validation_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--validation_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)

    # input files
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location"
    )
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [
            split_all_ext(f)[0] + args.gt_extension for f in input_image_files
        ]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = create_dataset(args.dataset,
                             DataSetMode.TRAIN,
                             images=input_image_files,
                             texts=gt_txt_files,
                             skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        # Sorted like the training images, so the pairwise basename check below lines
        # them up with the sorted validation text files.
        validation_image_files = sorted(glob_all(args.validation))
        if not args.validation_text_files:
            val_txt_files = [
                split_all_ext(f)[0] + args.validation_extension
                for f in validation_image_files
            ]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                        os.path.basename(gt))[0]:
                    raise Exception(
                        "Expected identical basenames of validation file: {} and {}"
                        .format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception(
                "Some validation images are occurring more than once in the data set."
            )

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(
            len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    # Train outside the ``with`` so the checkpoint file is not held open for the whole run.
    trainer = Trainer(checkpoint_params,
                      dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)
예제 #7
0
def main():
    """Evaluate one or more checkpoints and several voting strategies on a dataset.

    Ground truth for every image in ``--eval_imgs`` is expected next to it, with the same
    base name and the ``.gt.txt`` extension. Predictions are evaluated separately for:
    - each checkpoint, keyed by its index as a string;
    - each voter's combined prediction, keyed by the voter name.
    Results are printed with ``--verbose`` and/or pickled to ``--dump``.

    Raises:
        Exception: if the resolved dataset is empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+", default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm(s) to use. Possible values: confidence_voter_default_ctc, "
                             "confidence_voter_fuzzy_ctc, sequence_voter (default: all three)")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]
    if not args.checkpoint:
        # The evaluation text preprocessor is read from the first model below, so fail
        # early instead of crashing after prediction.
        parser.error("at least one --checkpoint is required")

    # load files: glob once and derive the ground-truth paths from the sorted images
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # The parser defines --eval_imgs; there is no ``args.files`` attribute.
        raise Exception("Empty dataset provided. Check your eval_imgs argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, _sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        # Evaluate one list of predicted sentences against the preloaded ground truth.
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        r = evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)

        return r

    full_evaluation = {}
    for eval_id, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[eval_id] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
예제 #8
0
def main():
    """Print summary statistics of a line-image dataset.

    Every image resolved from ``--files`` is paired with a ground-truth file that has
    the same base name and the ``.gt.txt`` extension. Missing ground-truth files are
    loaded as empty text (``non_existing_as_empty=True``), and empty samples are
    skipped when reading the training samples.

    The printed dict contains:
    - the number of lines;
    - line-width statistics (each image scaled to ``--line_height`` plus ``2 * --pad``);
    - per-line character-count statistics;
    - per-character counts and the codec size (number of distinct characters).

    Raises:
        ValueError: if no usable line (valid image and text) remains after loading.
    """
    import collections  # local import keeps this script entry point self-contained

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files",
        nargs="+",
        required=True,
        help="List of all image files with corresponding gt.txt files")
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")

    args = parser.parse_args()

    print("Resolving files")
    image_files = glob_all(args.files)
    gt_files = [split_all_ext(p)[0] + ".gt.txt" for p in image_files]

    ds = create_dataset(args.dataset,
                        DataSetMode.TRAIN,
                        images=image_files,
                        texts=gt_files,
                        non_existing_as_empty=True)

    print("Loading {} files".format(len(image_files)))
    ds.load_samples(processes=1, progress_bar=True)
    images, texts = ds.train_samples(skip_empty=True)

    # Width of each line once scaled to ``line_height`` (aspect ratio preserved) plus
    # left and right padding; missing or zero-sized images are ignored.
    widths = [
        img.shape[1] / img.shape[0] * args.line_height + 2 * args.pad
        for img in images
        if img is not None and img.shape[0] > 0 and img.shape[1] > 0
    ]
    chars = [len(text) for text in texts]
    if not widths or not chars:
        # np.max / np.min fail with an opaque error on empty input; report the real cause.
        raise ValueError(
            "No usable lines found. Check your files argument (got {})!".format(
                args.files))

    # Counter keeps first-seen insertion order; convert to a plain dict so the printed
    # output shows a dict rather than Counter(...).
    char_counts = dict(collections.Counter(c for text in texts for c in text))

    av_line_width = np.average(widths)
    av_chars = np.average(chars)
    # Keys are listed in a fixed order so the printed report stays stable.
    statistics = {
        "n_lines": len(images),
        "total_line_width": np.sum(widths),
        "char_counts": char_counts,
        "av_line_width": av_line_width,
        "max_line_width": np.max(widths),
        "min_line_width": np.min(widths),
        "av_chars": av_chars,
        "max_chars": np.max(chars),
        "min_chars": np.min(chars),
        "total_chars": np.sum(chars),
        "av_px_per_char": av_line_width / av_chars,
        "codec_size": len(char_counts),
    }

    print(statistics)
예제 #9
0
def main():
    """Resume training from a checkpoint on a training (and optional validation) dataset.

    Ground-truth files are resolved in one of two ways:
    - derived from the image paths: same base name plus ``--gt_extension`` /
      ``--validation_extension``, defaulting to the extension of the dataset type;
    - given explicitly via ``--text_files`` / ``--validation_text_files``. The images
      and texts are then paired with ``keep_files_with_same_file_name`` and checked
      for identical base names.

    Raises:
        Exception: if a paired image and ground-truth file differ in base name, or if a
            ground-truth file occurs more than once.
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth fils with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        # Sorted like the training images, so the pairwise basename check below lines
        # them up with the sorted validation text files.
        validation_image_files = sorted(glob_all(args.validation))
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    # Train outside the ``with`` so the checkpoint file is not held open for the whole run.
    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)