Example No. 1
    def __init__(
        self,
        mode: PipelineMode,
        files: List[str] = None,
        xmlfiles: List[str] = None,
        skip_invalid=False,
        remove_invalid=True,
        binary=False,
        non_existing_as_empty=False,
    ):
        """ Create a dataset from a Path as String

        Parameters
         ----------
        files : [], required
            image files
        skip_invalid : bool, optional
            skip invalid files
        remove_invalid : bool, optional
            remove invalid files
        """

        super().__init__(mode, skip_invalid, remove_invalid)

        self.xmlfiles = xmlfiles if xmlfiles else []
        self.files = files if files else []

        self._non_existing_as_empty = non_existing_as_empty
        if len(self.xmlfiles) == 0:
            from calamari_ocr.ocr.dataset import DataSetType
            self.xmlfiles = [
                split_all_ext(p)[0] + DataSetType.gt_extension(DataSetType.ABBYY)
                for p in files
            ]

        if len(self.files) == 0:
            self.files = [None] * len(self.xmlfiles)

        self.book = XMLReader(self.files, self.xmlfiles, skip_invalid,
                              remove_invalid).read()
        self.binary = binary

        for p, page in enumerate(self.book.pages):
            for l, line in enumerate(page.getLines()):
                for f, fo in enumerate(line.formats):
                    self.add_sample({
                        "image_path": page.imgFile,
                        "xml_path": page.xmlFile,
                        "id": "{}_{}_{}_{}".format(
                            os.path.splitext(page.xmlFile if page.xmlFile else page.imgFile)[0],
                            p, l, f),
                        "line": line,
                        "format": fo,
                    })
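This constructor appears to be the ABBYY reader that Example No. 4 imports as AbbyyReader; if no xmlfiles are given, it derives one ABBYY XML path per image and adds one sample per line format found in the parsed book. A minimal usage sketch under that assumption (the image paths are hypothetical, and PipelineMode is assumed to be importable as in the snippets above):

from calamari_ocr.ocr.dataset.datareader.abbyy import AbbyyReader

reader = AbbyyReader(
    PipelineMode.Training,
    files=["book/page_0001.png", "book/page_0002.png"],  # hypothetical image paths
    skip_invalid=True,
)
print(len(reader.samples()))  # one sample per line format found in the XML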
Example No. 2
    def prepare_for_mode(self, mode: PipelineMode) -> 'PipelineParams':
        from calamari_ocr.ocr.dataset.datareader.factory import DataReaderFactory
        assert (self.type is not None)
        params_out = deepcopy(self)
        # Training dataset
        logger.info("Resolving input files")
        if isinstance(self.type, str):
            try:
                self.type = DataSetType.from_string(self.type)
            except ValueError:
                # Not a valid type, must be custom
                if self.type not in DataReaderFactory.CUSTOM_READERS:
                    raise KeyError(
                        f"DataSetType {self.type} is neither a standard DataSetType nor registered as a custom "
                        f"reader ({list(DataReaderFactory.CUSTOM_READERS.keys())})"
                    )
        if not isinstance(self.type, str) and self.type not in {
                DataSetType.RAW, DataSetType.GENERATED_LINE
        }:
            input_image_files = sorted(glob_all(self.files)) if self.files else None

            if not self.text_files:
                if self.gt_extension:
                    gt_txt_files = [
                        split_all_ext(f)[0] + self.gt_extension
                        for f in input_image_files
                    ]
                else:
                    gt_txt_files = None
            else:
                gt_txt_files = sorted(glob_all(self.text_files))
                if mode in INPUT_PROCESSOR:
                    input_image_files, gt_txt_files = keep_files_with_same_file_name(
                        input_image_files, gt_txt_files)
                    for img, gt in zip(input_image_files, gt_txt_files):
                        img_base = split_all_ext(os.path.basename(img))[0]
                        gt_base = split_all_ext(os.path.basename(gt))[0]
                        if img_base != gt_base:
                            raise Exception(
                                "Expected identical basenames of file: {} and {}".format(img, gt))
                else:
                    input_image_files = None

            if mode in {PipelineMode.Training, PipelineMode.Evaluation}:
                if len(set(gt_txt_files)) != len(gt_txt_files):
                    logger.warning(
                        "Some ground truth text files occur more than once in the data set "
                        "(ignore this warning, if this was intended).")
                if len(set(input_image_files)) != len(input_image_files):
                    logger.warning(
                        "Some images occur more than once in the data set. "
                        "This warning should usually not be ignored.")

            params_out.files = input_image_files
            params_out.text_files = gt_txt_files
        return params_out
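A minimal sketch of calling this resolution step directly, as the training script in Example No. 5 does; the glob pattern and extension are hypothetical, and PipelineParams/DataSetType/PipelineMode are assumed to be importable as in the snippets above (unset fields fall back to their defaults):

params = PipelineParams(
    type=DataSetType.FILE,
    files=["lines/*.png"],   # hypothetical wildcard, expanded by glob_all
    gt_extension=".gt.txt",  # text file paths are derived from the image basenames
)
resolved = params.prepare_for_mode(PipelineMode.Training)
print(len(resolved.files), len(resolved.text_files))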
Example No. 3
    def finish_chunck(self):
        if len(self.text) == 0:
            return

        codec = self.compute_codec()

        filename = "{}_{:03d}{}".format(self.output_filename, self.current_chunk, DataSetType.gt_extension(DataSetType.HDF5))
        self.files.append(filename)
        file = h5py.File(filename, 'w')
        dti32 = h5py.special_dtype(vlen=np.dtype('int32'))
        dtui8 = h5py.special_dtype(vlen=np.dtype('uint8'))
        file.create_dataset('transcripts', (len(self.text),), dtype=dti32, compression='gzip')
        file.create_dataset('images_dims', data=[d.shape for d in self.data], dtype=int)
        file.create_dataset('images', (len(self.text),), dtype=dtui8, compression='gzip')
        file.create_dataset('codec', data=list(map(ord, codec)))
        file['transcripts'][...] = [list(map(codec.index, d)) for d in self.text]
        file['images'][...] = [d.reshape(-1) for d in self.data]
        file.close()

        self.current_chunk += 1
        self.data = []
        self.text = []
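A sketch of reading one of the written chunks back with h5py, based on the datasets created above ('codec', 'images', 'images_dims', 'transcripts'); the chunk filename is hypothetical:

import h5py

with h5py.File("lines_000.h5", "r") as f:  # hypothetical chunk file
    codec = "".join(chr(c) for c in f["codec"][...])
    for flat, shape, transcript in zip(f["images"][...], f["images_dims"][...],
                                       f["transcripts"][...]):
        image = flat.reshape(shape)                    # images are stored flattened
        text = "".join(codec[i] for i in transcript)   # transcripts are codec indices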
Example No. 4
    def create_data_reader(
        cls,
        type: Union[DataSetType, str],
        mode: PipelineMode,
        images: List[str] = None,
        texts: List[str] = None,
        skip_invalid=False,
        remove_invalid=True,
        non_existing_as_empty=False,
        args: FileDataReaderArgs = None,
    ) -> DataReader:
        if images is None:
            images = []

        if texts is None:
            texts = []

        if args is None:
            args = dict()

        if type in cls.CUSTOM_READERS:
            return cls.CUSTOM_READERS[type](
                mode=mode,
                images=images,
                texts=texts,
                skip_invalid=skip_invalid,
                remove_invalid=remove_invalid,
                non_existing_as_empty=non_existing_as_empty,
                args=args,
            )

        if DataSetType.files(type):
            if images:
                images.sort()

            if texts:
                texts.sort()

            if images and texts and len(images) > 0 and len(texts) > 0:
                images, texts = keep_files_with_same_file_name(images, texts)

        if type == DataSetType.RAW:
            from calamari_ocr.ocr.dataset.datareader.raw import RawDataReader
            return RawDataReader(mode, images, texts)

        elif type == DataSetType.FILE:
            from calamari_ocr.ocr.dataset.datareader.file import FileDataReader
            return FileDataReader(mode,
                                  images,
                                  texts,
                                  skip_invalid=skip_invalid,
                                  remove_invalid=remove_invalid,
                                  non_existing_as_empty=non_existing_as_empty)
        elif type == DataSetType.ABBYY:
            from calamari_ocr.ocr.dataset.datareader.abbyy import AbbyyReader
            return AbbyyReader(mode,
                               images,
                               texts,
                               skip_invalid=skip_invalid,
                               remove_invalid=remove_invalid,
                               non_existing_as_empty=non_existing_as_empty)
        elif type == DataSetType.PAGEXML:
            from calamari_ocr.ocr.dataset.datareader.pagexml.reader import PageXMLReader
            return PageXMLReader(mode,
                                 images,
                                 texts,
                                 skip_invalid=skip_invalid,
                                 remove_invalid=remove_invalid,
                                 non_existing_as_empty=non_existing_as_empty,
                                 args=args)
        elif type == DataSetType.HDF5:
            from calamari_ocr.ocr.dataset.datareader.hdf5 import Hdf5Reader
            return Hdf5Reader(mode, images, texts)
        elif type == DataSetType.EXTENDED_PREDICTION:
            from calamari_ocr.ocr.dataset.extended_prediction_dataset import ExtendedPredictionDataSet
            return ExtendedPredictionDataSet(texts=texts)
        elif type == DataSetType.GENERATED_LINE:
            from calamari_ocr.ocr.dataset.datareader.generated_line_dataset.dataset import GeneratedLineDataset
            return GeneratedLineDataset(mode, args=args)
        else:
            raise Exception("Unsupported dataset type {}".format(type))
Example No. 5
def run(args):
    # local imports (to prevent tensorflow from being imported too early)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    check_train_args(args)

    if args.output_dir is not None:
        args.output_dir = os.path.abspath(args.output_dir)
        setup_log(args.output_dir, append=False)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    if args.dataset == DataSetType.GENERATED_LINE or args.validation_dataset == DataSetType.GENERATED_LINE:
        if args.text_generator_params is not None:
            with open(args.text_generator_params, 'r') as f:
                args.text_generator_params = TextGeneratorParams.from_json(
                    f.read())
        else:
            args.text_generator_params = TextGeneratorParams()

        if args.line_generator_params is not None:
            with open(args.line_generator_params, 'r') as f:
                args.line_generator_params = LineGeneratorParams.from_json(
                    f.read())
        else:
            args.line_generator_params = LineGeneratorParams()

    dataset_args = FileDataReaderArgs(
        line_generator_params=args.line_generator_params,
        text_generator_params=args.text_generator_params,
        pad=args.dataset_pad,
        text_index=args.pagexml_text_index,
    )

    params: TrainerParams = Scenario.default_trainer_params()

    # =================================================================================================================
    # Data Params
    data_params: DataParams = params.scenario_params.data_params
    data_params.train = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=args.batch_size,
        num_processes=args.num_threads,
    )
    if args.validation_split_ratio:
        if args.validation is not None:
            raise ValueError("Set either validation_split_ratio or validation")
        if not 0 < args.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # resolve all files so we can split them
        data_params.train = data_params.train.prepare_for_mode(
            PipelineMode.Training)
        n = int(args.validation_split_ratio * len(data_params.train.files))
        logger.info(
            f"Splitting training and validation files with ratio {args.validation_split_ratio}: "
            f"{n}/{len(data_params.train.files) - n} for validation/training.")
        indices = list(range(len(data_params.train.files)))
        shuffle(indices)
        all_files = data_params.train.files
        all_text_files = data_params.train.text_files

        # split train and val img/gt files: the first n shuffled indices go to validation
        # (matching the ratio above), the rest stay in training. Use train settings.
        data_params.train.files = [all_files[i] for i in indices[n:]]
        if all_text_files is not None:
            assert len(all_text_files) == len(all_files)
            data_params.train.text_files = [
                all_text_files[i] for i in indices[n:]
            ]
        data_params.val = PipelineParams(
            type=args.dataset,
            skip_invalid=not args.no_skip_invalid_gt,
            remove_invalid=True,
            files=[all_files[i] for i in indices[:n]],
            text_files=[all_text_files[i] for i in indices[:n]]
            if data_params.train.text_files is not None else None,
            gt_extension=args.gt_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    elif args.validation:
        data_params.val = PipelineParams(
            type=args.validation_dataset,
            files=args.validation,
            text_files=args.validation_text_files,
            skip_invalid=not args.no_skip_invalid_gt,
            gt_extension=args.validation_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    else:
        data_params.val = None

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(
        run_parallel=False,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ])

    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        DataProcessorFactoryParams(PrepareSampleProcessor.__name__,
                                   INPUT_PROCESSOR),
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    # =================================================================================================================
    # Trainer Params
    params.calc_ema = args.ema_weights
    params.verbose = args.train_verbose
    params.force_eager = args.debug
    params.skip_model_load_test = not args.debug
    params.scenario_params.debug_graph_construction = args.debug
    params.epochs = args.epochs
    params.samples_per_epoch = int(args.samples_per_epoch) if args.samples_per_epoch >= 1 else -1
    params.scale_epoch_size = abs(args.samples_per_epoch) if args.samples_per_epoch < 1 else 1
    params.skip_load_model_test = True
    params.scenario_params.export_frozen = False
    params.checkpoint_save_freq_ = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.checkpoint_dir = args.output_dir
    params.test_every_n = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.data_aug_retrain_on_original = not args.only_train_on_augmented
    params.use_training_as_validation = args.use_train_as_val
    if args.seed > 0:
        params.random_seed = args.seed

    params.optimizer_params.clip_grad = args.gradient_clipping_norm
    params.codec_whitelist = whitelist
    params.keep_loaded_codec = args.keep_loaded_codec
    params.preload_training = not args.train_data_on_the_fly
    params.preload_validation = not args.validation_data_on_the_fly
    params.warmstart_params.model = args.weights

    params.auto_compute_codec = not args.no_auto_compute_codec
    params.progress_bar = not args.no_progress_bars

    params.early_stopping_params.frequency = args.early_stopping_frequency
    params.early_stopping_params.upper_threshold = 0.9
    params.early_stopping_params.lower_threshold = 1.0 - args.early_stopping_at_accuracy
    params.early_stopping_params.n_to_go = args.early_stopping_nbest
    params.early_stopping_params.best_model_name = ''
    params.early_stopping_params.best_model_output_dir = args.early_stopping_best_model_output_dir
    params.scenario_params.default_serve_dir_ = f'{args.early_stopping_best_model_prefix}.ckpt.h5'
    params.scenario_params.trainer_params_filename_ = f'{args.early_stopping_best_model_prefix}.ckpt.json'

    # =================================================================================================================
    # Model params
    params_from_definition_string(args.network, params)
    params.scenario_params.model_params.ensemble = args.ensemble
    params.scenario_params.model_params.masking_mode = args.masking_mode

    scenario = Scenario(params.scenario_params)
    trainer = scenario.create_trainer(params)
    trainer.train()
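Since run() accepts a single JSON file whose keys override the parsed arguments (with 'dataset' and 'validation_dataset' converted via DataSetType.from_string), here is a hedged sketch of producing such a file; the keys mirror attributes read above, and all values are hypothetical:

import json

config = {
    "dataset": "FILE",            # parsed with DataSetType.from_string
    "files": ["lines/*.png"],     # hypothetical image glob
    "gt_extension": ".gt.txt",
    "epochs": 10,
    "output_dir": "model_out",
}
with open("train_args.json", "w") as f:
    json.dump(config, f, indent=2)
# run() would then be invoked with args.files == ["train_args.json"]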
Example No. 6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files",
        nargs="+",
        help="List all image files that shall be processed. Ground truth files with the same "
        "base name but with '.gt.txt' as extension are required at the same location",
        required=True)
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--processes",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")

    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])
    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help=
        "The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")

    parser.add_argument("--preload",
                        action='store_true',
                        help='Simulate preloading')
    parser.add_argument("--as_validation",
                        action='store_true',
                        help="Access as validation instead of training data.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    dataset_args = FileDataReaderArgs(pad=args.pad, )

    data_params: DataParams = Data.get_default_params()
    data_params.train = PipelineParams(
        type=args.dataset,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=1,
        num_processes=args.processes,
    )
    data_params.val = data_params.train

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(run_parallel=True)
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        # DataProcessorFactoryParams(PrepareSampleProcessor.__name__),  # NOT THIS, since, we want to access raw input
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    data = Data(data_params)
    data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
    if not args.preload:
        reader: DataReader = data_pipeline.reader()
        if len(args.select) == 0:
            args.select = range(len(reader))
        else:
            reader._samples = [reader.samples()[i] for i in args.select]
    else:
        data.preload()
        data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
        samples = data_pipeline.samples
        if len(args.select) == 0:
            args.select = range(len(samples))
        else:
            data_pipeline.samples = [samples[i] for i in args.select]

    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(
                zip(args.select,
                    dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0
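The plotting loop above fills each figure column by column, n_rows lines per column, and opens a fresh figure once the grid is full. A self-contained sketch of that fill pattern with random images standing in for preprocessed text lines:

import numpy as np
import matplotlib.pyplot as plt

n_rows, n_cols = 5, 1
f, ax = plt.subplots(n_rows, n_cols, sharey="all")
row, col = 0, 0
for i in range(n_rows * n_cols):
    line = np.random.rand(48, 200)                  # stand-in for a line image
    axis = ax[row] if n_cols == 1 else ax[row, col]
    axis.imshow(line)
    axis.set_title("ID: {}".format(i))
    row += 1
    if row == n_rows:
        row, col = 0, col + 1
plt.show()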
Example No. 7
def run(args):
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed as extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]
    args.extension = args.extension if args.extension else DataSetType.pred_extension(args.dataset)

    # create ctc decoder
    ctc_decoder_params = create_ctc_decoder_params(args)

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterType(args.voter)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them; there won't be predictions for invalid files
    predict_params = PipelineParams(
        type=args.dataset,
        skip_invalid=True,
        remove_invalid=True,
        files=input_image_files,
        text_files=args.text_files,
        data_reader_args=FileDataReaderArgs(
            pad=args.dataset_pad,
            text_index=args.pagexml_text_index,
        ),
        batch_size=args.batch_size,
        num_processes=args.processes,
    )

    # predict for all models
    # TODO: Use CTC Decoder params
    from calamari_ocr.ocr.predict.predictor import MultiPredictor
    predictor = MultiPredictor.from_paths(
        checkpoints=args.checkpoint,
        voter_params=voter_params,
        predictor_params=PredictorParams(
            silent=True, progress_bar=not args.no_progress_bars))
    do_prediction = predictor.predict(predict_params)
    pipeline: CalamariPipeline = predictor.data.get_predict_data(
        predict_params)
    reader = pipeline.reader()
    if len(reader) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    avg_sentence_confidence = 0
    n_predictions = 0

    reader.prepare_store()

    # output the voted results to the appropriate files
    for s in do_prediction:
        inputs, (result, prediction), meta = s.inputs, s.outputs, s.meta
        sample = reader.sample_by_id(meta['id'])
        n_predictions += 1
        sentence = prediction.sentence

        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            lr = "\u202A\u202B"
            logger.info("{}: '{}{}{}'".format(meta['id'],
                                              lr[get_base_level(sentence)],
                                              sentence, "\u202C"))

        output_dir = args.output_dir

        reader.store_text(sentence,
                          sample,
                          output_dir=output_dir,
                          extension=args.extension)

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample['image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(ps.line_path)
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)

            if args.extended_prediction_data_format == "pred":
                data = zlib.compress(
                    ps.to_json(indent=2, ensure_ascii=False).encode('utf-8'))
            elif args.extended_prediction_data_format == "json":
                # remove logits
                for p in ps.predictions:
                    p.logits = None

                data = ps.to_json(indent=2)
            else:
                raise Exception("Unknown prediction format.")

            reader.store_extended_prediction(
                data,
                sample,
                output_dir=output_dir,
                extension=args.extended_prediction_data_format)

    logger.info("Average sentence confidence: {:.2%}".format(
        avg_sentence_confidence / n_predictions))

    reader.store(args.extension)
    logger.info("All prediction files written")