Example No. 1
def create_voter(self, data_params: 'DataParams') -> MultiModelVoter:
    # Cut the non-text processors (the first two)
    post_proc = [
        Data.data_processor_factory().create_sequence(
            data.params().post_processors_.sample_processors[2:], data.params(), PipelineMode.Prediction)
        for data in self.datas
    ]
    pre_proc = Data.data_processor_factory().create_sequence(
        self.data.params().pre_processors_.sample_processors, self.data.params(),
        PipelineMode.Prediction)
    out_to_in_transformer = OutputToInputTransformer(pre_proc)
    return CalamariMultiModelVoter(self.voter_params, self.datas, post_proc, out_to_in_transformer)
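For orientation: create_voter rebuilds one text-only post-processing pipeline per loaded model (dropping the first two, non-text processors) and passes the shared pre-processor to an OutputToInputTransformer, presumably so that prediction positions can be mapped back onto the original input. The actual combination of predictions is done by CalamariMultiModelVoter; as a purely illustrative sketch of the voting idea (not Calamari's algorithm), a plain majority vote over per-model decoded strings could look like this:

from collections import Counter

def majority_vote(decoded_texts):
    # Illustrative only: return the decoded string most models agree on.
    return Counter(decoded_texts).most_common(1)[0][0]

print(majority_vote(["word", "word", "wond"]))  # -> "word"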
Example No. 2
def main(args: EvalArgs):
    # Local imports (imports that require tensorflow)
    from calamari_ocr.ocr.scenario import CalamariScenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    if args.checkpoint:
        saved_model = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params = CalamariScenario.trainer_cls().params_cls().from_dict(saved_model.dict)
        data_params = trainer_params.scenario.data
    else:
        data_params = Data.default_params()

    data = Data(data_params)

    pred_data = args.pred if args.pred is not None else args.gt.to_prediction()
    evaluator = Evaluator(args.evaluator, data=data)
    evaluator.preload_gt(gt_dataset=args.gt)
    r = evaluator.run(gt_dataset=args.gt, pred_dataset=pred_data)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
            r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]
        )
    )

    # sort descending
    print_confusions(r, args.n_confusions)

    samples = data.create_pipeline(evaluator.params.setup, args.gt).reader().samples()
    print_worst_lines(r, samples, args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(
            args.xlsx_output,
            [
                {
                    "prefix": "evaluation",
                    "results": r,
                    "gt_files": [s["id"] for s in samples],
                }
            ],
        )

    return r
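The result dict r returned by Evaluator.run() carries the aggregate counts used in the print statement above. As a small hedged helper relying only on those keys, a micro-averaged label error rate can be recomputed directly; note it may differ slightly from the per-line mean reported as avg_ler:

def micro_label_error_rate(r):
    # Character errors divided by total characters; guard against empty ground truth.
    total = r["total_chars"]
    return r["total_char_errs"] / total if total else 0.0

# Made-up numbers, for illustration only:
print("{:.2%}".format(micro_label_error_rate({"total_char_errs": 42, "total_chars": 1000})))  # 4.20%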
Example No. 3
def get_preproc_image():
    data_params = Data.default_params()
    data_params.skip_invalid_gt = False
    data_params.pre_proc.run_parallel = False
    data_params.pre_proc.processors = data_params.pre_proc.processors[:-1]
    for p in data_params.pre_proc.processors_of_type(FinalPreparationProcessorParams):
        p.pad = 0
    post_init(data_params)
    pl = Data(data_params).create_pipeline(DataPipelineParams, None)
    pl.mode = PipelineMode.PREDICTION
    preproc = data_params.pre_proc.create(pl)

    def pp(image):
        its = InputSample(
            image, None, SampleMeta("001", fold_id="01")
        ).to_input_target_sample()
        s = preproc.apply_on_sample(its)
        return s.inputs

    return pp
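A minimal usage sketch, assuming get_preproc_image() above (and the names it uses) is in scope and that the input is a grayscale line image as a numpy array; the file name is a placeholder:

import numpy as np
from PIL import Image

pp = get_preproc_image()
line = np.array(Image.open("line_001.png").convert("L"))  # hypothetical input file
inputs = pp(line)  # preprocessed line image (padding disabled above via p.pad = 0)
print(inputs.shape)  # assuming the result comes back as a numpy array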
Example No. 4
def get_preproc_text(rtl=False):
    data_params = Data.default_params()
    data_params.skip_invalid_gt = False
    data_params.pre_proc.run_parallel = False

    if rtl:
        for p in data_params.pre_proc.processors_of_type(BidiTextProcessorParams):
            p.bidi_direction = BidiDirection.RTL
    post_init(data_params)

    pl = Data(data_params).create_pipeline(DataPipelineParams, None)
    pl.mode = PipelineMode.TARGETS
    preproc = data_params.pre_proc.create(pl)

    def pp(text):
        its = InputSample(
            None, text, SampleMeta("001", fold_id="01")
        ).to_input_target_sample()
        s = preproc.apply_on_sample(its)
        return s.targets

    return pp
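The text-side counterpart is used the same way; a minimal sketch, assuming get_preproc_text() above is in scope (the exact output depends on the text processors configured in Data.default_params()):

pp = get_preproc_text()              # default (left-to-right) text pipeline
pp_rtl = get_preproc_text(rtl=True)  # bidi reordering forced to RTL

print(pp("  some ground truth line  "))  # e.g. normalized/stripped text
print(pp_rtl("שלום"))                    # reordered for right-to-left display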
Example No. 5
def run(args):
    # Local imports (to prevent tensorflow from being imported too early)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    check_train_args(args)

    if args.output_dir is not None:
        args.output_dir = os.path.abspath(args.output_dir)
        setup_log(args.output_dir, append=False)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    if args.dataset == DataSetType.GENERATED_LINE or args.validation_dataset == DataSetType.GENERATED_LINE:
        if args.text_generator_params is not None:
            with open(args.text_generator_params, 'r') as f:
                args.text_generator_params = TextGeneratorParams.from_json(
                    f.read())
        else:
            args.text_generator_params = TextGeneratorParams()

        if args.line_generator_params is not None:
            with open(args.line_generator_params, 'r') as f:
                args.line_generator_params = LineGeneratorParams.from_json(
                    f.read())
        else:
            args.line_generator_params = LineGeneratorParams()

    dataset_args = FileDataReaderArgs(
        line_generator_params=args.line_generator_params,
        text_generator_params=args.text_generator_params,
        pad=args.dataset_pad,
        text_index=args.pagexml_text_index,
    )

    params: TrainerParams = Scenario.default_trainer_params()

    # =================================================================================================================
    # Data Params
    data_params: DataParams = params.scenario_params.data_params
    data_params.train = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=args.batch_size,
        num_processes=args.num_threads,
    )
    if args.validation_split_ratio:
        if args.validation is not None:
            raise ValueError("Set either validation_split_ratio or validation")
        if not 0 < args.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # resolve all files so we can split them
        data_params.train = data_params.train.prepare_for_mode(
            PipelineMode.Training)
        n = int(args.validation_split_ratio * len(data_params.train.files))
        logger.info(
            f"Splitting training and validation files with ratio {args.validation_split_ratio}: "
            f"{n}/{len(data_params.train.files) - n} for validation/training.")
        indices = list(range(len(data_params.train.files)))
        shuffle(indices)
        all_files = data_params.train.files
        all_text_files = data_params.train.text_files

        # split train and val img/gt files. Use train settings
        data_params.train.files = [all_files[i] for i in indices[:n]]
        if all_text_files is not None:
            assert (len(all_text_files) == len(all_files))
            data_params.train.text_files = [
                all_text_files[i] for i in indices[:n]
            ]
        data_params.val = PipelineParams(
            type=args.dataset,
            skip_invalid=not args.no_skip_invalid_gt,
            remove_invalid=True,
            files=[all_files[i] for i in indices[n:]],
            text_files=[all_text_files[i] for i in indices[n:]]
            if data_params.train.text_files is not None else None,
            gt_extension=args.gt_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    elif args.validation:
        data_params.val = PipelineParams(
            type=args.validation_dataset,
            files=args.validation,
            text_files=args.validation_text_files,
            skip_invalid=not args.no_skip_invalid_gt,
            gt_extension=args.validation_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    else:
        data_params.val = None

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(
        run_parallel=False,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ])

    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        DataProcessorFactoryParams(PrepareSampleProcessor.__name__,
                                   INPUT_PROCESSOR),
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    # =================================================================================================================
    # Trainer Params
    params.calc_ema = args.ema_weights
    params.verbose = args.train_verbose
    params.force_eager = args.debug
    params.skip_model_load_test = not args.debug
    params.scenario_params.debug_graph_construction = args.debug
    params.epochs = args.epochs
    params.samples_per_epoch = int(
        args.samples_per_epoch) if args.samples_per_epoch >= 1 else -1
    params.scale_epoch_size = abs(
        args.samples_per_epoch) if args.samples_per_epoch < 1 else 1
    params.skip_load_model_test = True
    params.scenario_params.export_frozen = False
    params.checkpoint_save_freq_ = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.checkpoint_dir = args.output_dir
    params.test_every_n = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.data_aug_retrain_on_original = not args.only_train_on_augmented
    params.use_training_as_validation = args.use_train_as_val
    if args.seed > 0:
        params.random_seed = args.seed

    params.optimizer_params.clip_grad = args.gradient_clipping_norm
    params.codec_whitelist = whitelist
    params.keep_loaded_codec = args.keep_loaded_codec
    params.preload_training = not args.train_data_on_the_fly
    params.preload_validation = not args.validation_data_on_the_fly
    params.warmstart_params.model = args.weights

    params.auto_compute_codec = not args.no_auto_compute_codec
    params.progress_bar = not args.no_progress_bars

    params.early_stopping_params.frequency = args.early_stopping_frequency
    params.early_stopping_params.upper_threshold = 0.9
    params.early_stopping_params.lower_threshold = 1.0 - args.early_stopping_at_accuracy
    params.early_stopping_params.n_to_go = args.early_stopping_nbest
    params.early_stopping_params.best_model_name = ''
    params.early_stopping_params.best_model_output_dir = args.early_stopping_best_model_output_dir
    params.scenario_params.default_serve_dir_ = f'{args.early_stopping_best_model_prefix}.ckpt.h5'
    params.scenario_params.trainer_params_filename_ = f'{args.early_stopping_best_model_prefix}.ckpt.json'

    # =================================================================================================================
    # Model params
    params_from_definition_string(args.network, params)
    params.scenario_params.model_params.ensemble = args.ensemble
    params.scenario_params.model_params.masking_mode = args.masking_mode

    scenario = Scenario(params.scenario_params)
    trainer = scenario.create_trainer(params)
    trainer.train()
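One detail worth spelling out is how --samples_per_epoch is mapped onto the two trainer fields above: values >= 1 are taken as an absolute sample count, while values below 1 leave samples_per_epoch at -1 (full dataset) and scale the epoch size instead. A standalone sketch of that mapping (no calamari types involved):

def interpret_samples_per_epoch(value):
    # Returns (samples_per_epoch, scale_epoch_size) as assigned on the trainer params above.
    if value >= 1:
        return int(value), 1  # absolute number of samples per epoch
    return -1, abs(value)     # -1 = full dataset size, scaled by the given fraction

print(interpret_samples_per_epoch(5000))  # (5000, 1)
print(interpret_samples_per_epoch(0.25))  # (-1, 0.25)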
Example No. 6
def setup_train_args(parser, omit=None):
    # required params for args
    from calamari_ocr.ocr.dataset.data import Data

    if omit is None:
        omit = []

    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)

    if "files" not in omit:
        parser.add_argument(
            "--files",
            nargs="+",
            default=[],
            help=
            "List all image files that shall be processed. Ground truth fils with the same "
            "base name but with '.gt.txt' as extension are required at the same location"
        )
        parser.add_argument(
            "--text_files",
            nargs="+",
            default=None,
            help="Optional list of GT files if they are in other directory")
        parser.add_argument(
            "--gt_extension",
            default=None,
            help=
            "Default extension of the gt files (expected to exist in same dir)"
        )
        parser.add_argument("--dataset",
                            type=DataSetType.from_string,
                            choices=list(DataSetType),
                            default=DataSetType.FILE)

    parser.add_argument(
        "--train_data_on_the_fly",
        action='store_true',
        default=False,
        help=
        'Instead of preloading all data during the training, load the data on the fly. '
        'This is slower, but might be required for limited RAM or large datasets'
    )

    parser.add_argument(
        "--seed",
        type=int,
        default="0",
        help=
        "Seed for random operations. If negative or zero a 'random' seed is used"
    )
    parser.add_argument(
        "--network",
        type=str,
        default="cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5",
        help="The network structure")
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument(
        "--display",
        type=int,
        default=1,
        help=
        "Frequency of how often an output shall occur during training [epochs]"
    )
    parser.add_argument("--batch_size",
                        type=int,
                        default=1,
                        help="The batch size to use for training")
    parser.add_argument("--ema_weights",
                        action="store_true",
                        default=False,
                        help="Use exponentially averaged weights")
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=-1,
        help=
        "The frequency how often to write checkpoints during training [epochs]"
        "If -1 (default), the early_stopping_frequency will be used. If 0 no checkpoints are written"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="The number of iterations for training. "
        "If using early stopping, this is the maximum number of iterations")
    parser.add_argument(
        "--samples_per_epoch",
        type=float,
        default=-1,
        help=
        "The number of samples to process per epoch. By default the size of the training dataset."
        "If in (0,1) it is relative to the dataset size")
    parser.add_argument(
        "--early_stopping_at_accuracy",
        type=float,
        default=1.0,
        help="Stop training if the early stopping accuracy reaches this value")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")

    if "output_dir" not in omit:
        parser.add_argument(
            "--output_dir",
            type=str,
            default="",
            help="Default directory where to store checkpoints and models")
    if "output_model_prefix" not in omit:
        parser.add_argument("--output_model_prefix",
                            type=str,
                            default="model_",
                            help="Prefix for storing checkpoints and models")

    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help=
        "The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")

    if "weights" not in omit:
        parser.add_argument("--weights",
                            type=str,
                            default=None,
                            help="Load network weights from the given file.")

    parser.add_argument(
        "--no_auto_compute_codec",
        action='store_true',
        default=False,
        help="Do not compute the codec automatically. See also whitelist")
    parser.add_argument(
        "--whitelist_files",
        type=str,
        nargs="+",
        default=[],
        help=
        "Whitelist of txt files that may not be removed on restoring a model")
    parser.add_argument(
        "--whitelist",
        type=str,
        nargs="+",
        default=[],
        help=
        "Whitelist of characters that may not be removed on restoring a model. "
        "For large dataset you can use this to skip the automatic codec computation "
        "(see --no_auto_compute_codec)")
    parser.add_argument(
        "--keep_loaded_codec",
        action='store_true',
        default=False,
        help="Fully include the codec of the loaded model to the new codec")
    parser.add_argument("--ensemble",
                        type=int,
                        default=-1,
                        help="Number of voting models")
    parser.add_argument("--masking_mode",
                        type=int,
                        default=0,
                        help="Do not use")  # TODO: remove

    # clipping
    parser.add_argument("--gradient_clipping_norm",
                        type=float,
                        default=5,
                        help="Clipping constant of the norm of the gradients.")

    # early stopping
    if "validation" not in omit:
        parser.add_argument(
            "--validation",
            type=str,
            nargs="+",
            help="Validation line files used for early stopping")
        parser.add_argument(
            "--validation_text_files",
            nargs="+",
            default=None,
            help=
            "Optional list of validation GT files if they are in other directory"
        )
        parser.add_argument(
            "--validation_extension",
            default=None,
            help=
            "Default extension of the gt files (expected to exist in same dir). "
            "By default same as gt_extension")
        parser.add_argument(
            "--validation_dataset",
            type=DataSetType.from_string,
            choices=list(DataSetType),
            default=None,
            help=
            "Default validation data set type. By default same as --dataset")
        parser.add_argument("--use_train_as_val",
                            action='store_true',
                            default=False)
        parser.add_argument(
            "--validation_split_ratio",
            type=float,
            default=None,
            help=
            "Use n percent of the training dataset for validation. Can not be used with --validation"
        )

    parser.add_argument(
        "--validation_data_on_the_fly",
        action='store_true',
        default=False,
        help=
        'Instead of preloading all data during the training, load the data on the fly. '
        'This is slower, but might be required for limited RAM or large datasets'
    )

    parser.add_argument("--early_stopping_frequency",
                        type=int,
                        default=1,
                        help="The frequency of early stopping [epochs].")
    parser.add_argument(
        "--early_stopping_nbest",
        type=int,
        default=5,
        help=
        "The number of models that must be worse than the current best model to stop"
    )
    if "early_stopping_best_model_prefix" not in omit:
        parser.add_argument(
            "--early_stopping_best_model_prefix",
            type=str,
            default="best",
            help="The prefix of the best model using early stopping")
    if "early_stopping_best_model_output_dir" not in omit:
        parser.add_argument(
            "--early_stopping_best_model_output_dir",
            type=str,
            default=None,
            help="Path where to store the best model. Default is output_dir")
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument(
        "--only_train_on_augmented",
        action="store_true",
        default=False,
        help=
        "When training with augmentations usually the model is retrained in a second run with "
        "only the non augmented data. This will take longer. Use this flag to disable this "
        "behavior.")

    # text normalization/regularization
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])

    # text/line generation params (loaded from json files)
    parser.add_argument("--text_generator_params", type=str, default=None)
    parser.add_argument("--line_generator_params", type=str, default=None)

    # additional dataset args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    parser.add_argument("--pagexml_text_index", default=0)

    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--train_verbose", default=1)
Example No. 7
def main(args=None):
    parser = PAIArgumentParser()
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)

    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    parser.add_argument("--preload", action='store_true', help='Simulate preloading')
    parser.add_argument("--as_validation", action='store_true', help="Access as validation instead of training data.")
    parser.add_argument("--n_augmentations", type=float, default=0)
    parser.add_argument("--no_plot", action='store_true', help='This parameter is for testing only')

    parser.add_root_argument("data", DataWrapper)
    args = parser.parse_args(args=args)

    data_wrapper: DataWrapper = args.data
    data_params = data_wrapper.data
    data_params.pre_proc.run_parallel = False
    data_params.pre_proc.erase_all(PrepareSampleProcessorParams)
    for p in data_params.pre_proc.processors_of_type(AugmentationProcessorParams):
        p.n_augmentations = args.n_augmentations
    data_params.__post_init__()
    data_wrapper.pipeline.mode = PipelineMode.EVALUATION if args.as_validation else PipelineMode.TRAINING
    data_wrapper.gen.prepare_for_mode(data_wrapper.pipeline.mode)

    data = Data(data_params)
    if len(args.select) == 0:
        args.select = list(range(len(data_wrapper.gen)))
    else:
        try:
            data_wrapper.gen.select(args.select)
        except NotImplementedError:
            logger.warning(f"Selecting is not supported for a data generator of type {type(data_wrapper.gen)}. "
                           f"Resuming without selection.")
    data_pipeline = data.create_pipeline(data_wrapper.pipeline, data_wrapper.gen)
    if args.preload:
        data_pipeline = data_pipeline.as_preloaded()

    if args.no_plot:
        with data_pipeline as dataset:
            list(zip(args.select, dataset.generate_input_samples(auto_repeat=False)))
        return

    import matplotlib.pyplot as plt
    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(zip(args.select, dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample.inputs, sample.targets, sample.meta
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0

    plt.show()
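The plotting loop fills the subplot grid column by column: the row index advances first, a full column moves col one step to the right, and a full grid triggers plt.show() followed by a fresh figure. A small standalone sketch of that traversal order:

def grid_positions(n_samples, n_rows, n_cols):
    # Yield the (row, col) cell used for each sample, in the same order as the loop above.
    row = col = 0
    for _ in range(n_samples):
        yield row, col
        row += 1
        if row == n_rows:   # column full: move to the next column
            row, col = 0, col + 1
        if col == n_cols:   # grid full: the loop above shows the figure and starts over
            col = 0

print(list(grid_positions(7, 5, 1)))  # (0,0) ... (4,0), then (0,0), (1,0) on a new figure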
Example No. 8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location",
        required=True)
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--processes",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")

    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])
    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help=
        "The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")

    parser.add_argument("--preload",
                        action='store_true',
                        help='Simulate preloading')
    parser.add_argument("--as_validation",
                        action='store_true',
                        help="Access as validation instead of training data.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    dataset_args = FileDataReaderArgs(pad=args.pad)

    data_params: DataParams = Data.get_default_params()
    data_params.train = PipelineParams(
        type=args.dataset,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=1,
        num_processes=args.processes,
    )
    data_params.val = data_params.train

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(run_parallel=True)
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        # DataProcessorFactoryParams(PrepareSampleProcessor.__name__),  # NOT THIS, since we want to access the raw input
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    data = Data(data_params)
    data_pipeline = (data.get_val_data()
                     if args.as_validation else data.get_train_data())
    if not args.preload:
        reader: DataReader = data_pipeline.reader()
        if len(args.select) == 0:
            args.select = range(len(reader))
        else:
            reader._samples = [reader.samples()[i] for i in args.select]
    else:
        data.preload()
        data_pipeline = (data.get_val_data()
                         if args.as_validation else data.get_train_data())
        samples = data_pipeline.samples
        if len(args.select) == 0:
            args.select = range(len(samples))
        else:
            data_pipeline.samples = [samples[i] for i in args.select]

    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(
                zip(args.select,
                    dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0
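The --data_preprocessing loop above starts from each processor's default parameters and only overrides 'pad' when the defaults expose it. A standalone sketch of that override pattern using plain dicts (a copy is added here so the factory defaults stay untouched, which the in-place version above does not do):

def override_pad(default_params, pad):
    # Copy the defaults and replace 'pad' if the processor supports it.
    params = dict(default_params)
    if 'pad' in params:
        params['pad'] = pad
    return params

print(override_pad({'pad': 16, 'line_height': 48}, 0))  # {'pad': 0, 'line_height': 48}
print(override_pad({'line_height': 48}, 0))             # unchanged copy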
Example No. 9
def main():
    # Local imports (imports that require tensorflow)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension). "
                             "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with its error")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as would it be empty (fully checking for errors)."
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--skip_empty_gt", action="store_true", default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args
    parser.add_argument("--pagexml_gt_text_index", default=0)
    parser.add_argument("--pagexml_pred_text_index", default=1)

    args = parser.parse_args()

    # check if loading a json file
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    logger.info("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    if args.non_existing_file_handling_mode.lower() == "skip":
        non_existing_pred = [p for p in pred_files if not os.path.exists(p)]
        for f in non_existing_pred:
            idx = pred_files.index(f)
            del pred_files[idx]
            del gt_files[idx]

    data_params = Data.get_default_params()
    if args.checkpoint:
        saved_model = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params = Scenario.trainer_params_from_dict(saved_model.dict)
        data_params = trainer_params.scenario_params.data_params

    data = Data(data_params)

    gt_reader_args = FileDataReaderArgs(
        text_index=args.pagexml_gt_text_index
    )
    pred_reader_args = FileDataReaderArgs(
        text_index=args.pagexml_pred_text_index
    )
    gt_data_set = PipelineParams(
        type=args.dataset,
        text_files=gt_files,
        data_reader_args=gt_reader_args,
        skip_invalid=args.skip_empty_gt,
    )
    pred_data_set = PipelineParams(
        type=args.pred_dataset,
        text_files=pred_files,
        data_reader_args=pred_reader_args,
    )

    evaluator = Evaluator(data=data)
    evaluator.preload_gt(gt_dataset=gt_data_set)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, data.create_pipeline(PipelineMode.Targets, gt_data_set).reader().samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                   }])
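For the "skip" handling mode above, missing prediction files are removed pairwise from both lists. An equivalent, self-contained sketch that keeps gt and pred aligned in a single pass:

import os

def drop_missing_predictions(gt_files, pred_files):
    # Drop gt/pred pairs whose prediction file does not exist; both lists stay aligned.
    kept = [(g, p) for g, p in zip(gt_files, pred_files) if os.path.exists(p)]
    return [g for g, _ in kept], [p for _, p in kept]

gt_files, pred_files = drop_missing_predictions(["a.gt.txt", "b.gt.txt"],
                                                ["a.pred.txt", "missing.pred.txt"])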