예제 #1
0
    def predict_books(self, books, models, pageupload=False, text_index=1):
        """Predict text for all lines of ``books`` and store the voted results.

        Every line is run through all given models. The per-model results are
        combined with the default CTC confidence voter, and the voted sentence is
        written via ``Nash5DataSet.store_text`` with extension ``.pred.txt``.

        :param books: a single book name or a list of book names in the cache file.
        :param models: a single checkpoint path or a list of checkpoint paths.
        :param pageupload: accepted for interface compatibility; not used here.
        :param text_index: accepted for interface compatibility; not used here.
        """
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)

        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
        voter = voter_from_proto(voter_params)

        # predict for all models
        predictor = MultiPredictor(checkpoints=models, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1)
        do_prediction = predictor.predict_dataset(dset, progress_bar=True)

        avg_sentence_confidence = 0
        n_predictions = 0
        # output the voted results to the appropriate files
        for result, sample in do_prediction:
            n_predictions += 1
            for i, p in enumerate(result):
                p.prediction.id = "fold_{}".format(i)

            # vote the results (if only one model is given, this will just return the sentences)
            prediction = voter.vote_prediction_result(result)
            prediction.id = "voted"
            avg_sentence_confidence += prediction.avg_char_probability

            dset.store_text(prediction.sentence, sample, output_dir=None, extension=".pred.txt")

        # an empty dataset yields no predictions; avoid ZeroDivisionError
        avg_conf = avg_sentence_confidence / n_predictions if n_predictions else 0
        print("Average sentence confidence: {:.2%}".format(avg_conf))

        dset.store()
        print("All files written")
예제 #2
0
 def __init__(self, settings: AlgorithmPredictorSettings):
     """Load the ``omr_best*.ckpt.json`` checkpoints under ``settings.model.path``.

     Also stores the first network's feature count as ``self.height`` and builds a
     default-CTC confidence voter in ``self.voter``.
     """
     super().__init__(settings)
     checkpoint_files = glob_all([settings.model.path + '/omr_best*.ckpt.json'])
     decoder_params = settings.params.ctcDecoder.params
     self.predictor = MultiPredictor(checkpoint_files, ctc_decoder_params=decoder_params)
     first_predictor = self.predictor.predictors[0]
     self.height = first_predictor.network_params.features
     config = VoterParams()
     config.type = VoterParams.CONFIDENCE_VOTER_DEFAULT_CTC
     self.voter = voter_from_proto(config)
예제 #3
0
    def evaluate_books(self, books, models, mode="auto", sample=-1):
        """Score each model (or model ensemble) on the lines of ``books``.

        :param books: a single book name or a list of book names in the cache file.
        :param models: a list whose entries are a checkpoint path or a list of
            checkpoint paths (an ensemble); a single string is treated as one model.
        :param mode: ``"conf"`` scores by the average voted character confidence,
            ``"eval"`` scores by ``1 - avg_ler`` against the ground truth, and
            ``"auto"`` picks ``"eval"`` if any line in the cache has a ``text``
            attribute and ``"conf"`` otherwise.
        :param sample: if ``0 < sample < len(dataset)``, only a random subset of
            ``sample`` lines is evaluated.
        :return: dict mapping ``"/".join(model)`` to its score (0 in ``"conf"``
            mode when nothing was predicted).
        """
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        results = {}
        if mode == "auto":
            with h5py.File(self.cachefile, 'r', libver='latest', swmr=True) as cache:
                # any() stops at the first line that carries a ground truth text
                has_gt = any("text" in cache[b][p][s].attrs
                             for b in books for p in cache[b] for s in cache[b][p])
            mode = "eval" if has_gt else "conf"

        if mode == "conf":
            dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
        else:
            dset = Nash5DataSet(DataSetMode.EVAL, self.cachefile, books)

        if 0 < sample < len(dset):
            delsamples = random.sample(dset._samples, len(dset) - sample)
            for s in delsamples:
                dset._samples.remove(s)

        if mode == "conf":
            # the voter does not depend on the model, so build it only once
            voter_params = VoterParams()
            voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
            voter = voter_from_proto(voter_params)
            for model in models:
                if isinstance(model, str):
                    model = [model]
                predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1)
                do_prediction = predictor.predict_dataset(dset, progress_bar=True)
                avg_sentence_confidence = 0
                n_predictions = 0
                for result, _ in do_prediction:
                    n_predictions += 1
                    prediction = voter.vote_prediction_result(result)
                    avg_sentence_confidence += prediction.avg_char_probability
                # guard against ZeroDivisionError on an empty dataset
                results["/".join(model)] = avg_sentence_confidence / n_predictions if n_predictions else 0

        else:
            for model in models:
                if isinstance(model, str):
                    model = [model]
                # same ``checkpoints`` keyword as in the "conf" branch above
                predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1, with_gt=True)
                out_gen = predictor.predict_dataset(dset, progress_bar=True, apply_preproc=False)
                result = Evaluator.evaluate_single_list(map(Evaluator.evaluate_single_args,
                            map(lambda d: tuple([''.join(d[0].ground_truth), ''.join(d[0].chars)]), out_gen)))
                results["/".join(model)] = 1 - result["avg_ler"]
        return results
예제 #4
0
    def _init_calamari(self):
        """Build the Calamari predictor and voter from ``self.parameter``.

        Exports ``TF_CPP_MIN_LOG_LEVEL`` first, expands the ``checkpoint`` glob
        pattern into a ``MultiPredictor`` and creates the voter whose name is
        given by the ``voter`` parameter.
        """
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = TF_CPP_MIN_LOG_LEVEL

        params = self.parameter
        self.predictor = MultiPredictor(checkpoints=glob(params['checkpoint']))

        proto = VoterParams()
        proto.type = VoterParams.Type.Value(params['voter'].upper())
        self.voter = voter_from_proto(proto)
예제 #5
0
    def predict_books(self, books, models, pageupload=True, text_index=1):
        """Predict text for all lines of ``books`` and upload or store the results.

        Each line is predicted with every model in ``models``. The results are
        combined with the default CTC confidence voter.

        :param books: a single book name or a list of book names in the cache file.
        :param models: a single checkpoint path or a list of checkpoint paths.
        :param pageupload: if true, POST the voted texts as gzip-compressed JSON to
            ``self.baseurl + "/_ocrdata"``; otherwise persist them with
            ``Nash5DataSet.store``.
        :param text_index: sent as ``"index"`` together with the uploaded data.
        """
        if not pageupload:
            print("""Warning: trying to save results to the hdf5-Cache may fail due to some issue
                  with file access from multiple threads. It should work, however, if you set
                  export HDF5_USE_FILE_LOCKING='FALSE'.""")
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)

        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
        voter = voter_from_proto(voter_params)

        # predict for all models
        predictor = MultiPredictor(checkpoints=models, data_preproc=PadNoopDataPreprocessor(), batch_size=1, processes=1)
        do_prediction = predictor.predict_dataset(dset, progress_bar=True)

        avg_sentence_confidence = 0
        n_predictions = 0
        # output the voted results to the appropriate files
        for result, sample in do_prediction:
            n_predictions += 1
            for i, p in enumerate(result):
                p.prediction.id = "fold_{}".format(i)

            # vote the results (if only one model is given, this will just return the sentences)
            prediction = voter.vote_prediction_result(result)
            prediction.id = "voted"
            avg_sentence_confidence += prediction.avg_char_probability

            dset.store_text(prediction.sentence, sample, output_dir=None, extension=".pred.txt")
        # an empty dataset yields no predictions; avoid ZeroDivisionError
        avg_conf = avg_sentence_confidence / n_predictions if n_predictions else 0
        print("Average sentence confidence: {:.2%}".format(avg_conf))

        if pageupload:
            # regroup the flat "<x>/<book>/<page>/<line>" keys into
            # ocrdata[book][page][line] = text
            ocrdata = {}
            for lname, text in dset.predictions.items():
                _, book, page, line = lname.split("/")
                ocrdata.setdefault(book, {}).setdefault(page, {})[line] = text

            data = {"ocrdata": ocrdata, "index": text_index}
            self.session.post(self.baseurl+"/_ocrdata",
                              data=gzip.compress(json.dumps(data).encode("utf-8")),
                              headers={"Content-Type": "application/json;charset=UTF-8",
                                       "Content-Encoding": "gzip"})
            print("Results uploaded")
        else:
            dset.store()
            print("All files written")
예제 #6
0
    def __init__(self, params: SymbolDetectionPredictorParameters):
        """Configure symbol detection and load the ``omr_best*`` checkpoints.

        Mutates ``params.symbol_detection_params`` in place: masks become an input,
        the FCN height is 80, and the FCN model is ``<first checkpoint>/pc_model``.
        Then loads every ``omr_best*.ckpt.json`` under each checkpoint directory,
        records the first network's feature count as ``self.height`` and builds a
        default-CTC confidence voter.
        """
        super().__init__(params)

        sd_params = params.symbol_detection_params
        sd_params.masks_as_input = True
        sd_params.apply_fcn_height = 80
        sd_params.apply_fcn_model = os.path.join(params.checkpoints[0], 'pc_model')

        patterns = [ckpt_dir + '/omr_best*.ckpt.json' for ckpt_dir in params.checkpoints]
        self.predictor = MultiPredictor(glob_all(patterns))
        self.height = self.predictor.predictors[0].network_params.features

        voter_config = VoterParams()
        voter_config.type = VoterParams.CONFIDENCE_VOTER_DEFAULT_CTC
        self.voter = voter_from_proto(voter_config)
예제 #7
0
    def setup(self):
        """
        Set up the model prior to processing.

        Loads every ``*.ckpt.json`` found in the resolved ``checkpoint_dir``,
        records the first network's input channel count and builds the voter
        named by the ``voter`` parameter.
        """
        checkpoint_dir = self.resolve_resource(self.parameter['checkpoint_dir'])
        self.predictor = MultiPredictor(
            checkpoints=glob('%s/*.ckpt.json' % checkpoint_dir))

        first_predictor = self.predictor.predictors[0]
        self.network_input_channels = first_predictor.network.input_channels
        # No feature selection is applied. Deriving it from the model's
        # binarization setting was considered but is currently unused.
        self.features = ''

        voter_name = self.parameter['voter'].upper()
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter_name)
        self.voter = voter_from_proto(voter_params)
예제 #8
0
    def evaluate_books(self, books, models, rtl=False, mode="auto", sample=-1):
        """Score each model (or model ensemble) on the lines of ``books``.

        :param books: a single book name or a list of book names in the cache file.
        :param models: a list whose entries are a checkpoint path or a list of
            checkpoint paths (an ensemble); a single string is treated as one model.
        :param rtl: in ``"eval"`` mode, preprocess predicted texts with
            ``self.bidi_preproc`` instead of ``self.txt_preproc``.
        :param mode: ``"conf"`` scores by the average voted character confidence,
            ``"eval"`` scores by ``1 - avg_ler`` against the ground truth, and
            ``"auto"`` picks ``"eval"`` if any line in the cache has a ``text``
            attribute and ``"conf"`` otherwise.
        :param sample: if ``0 < sample < len(dataset)``, only a random subset of
            ``sample`` lines is evaluated.
        :return: dict mapping ``"/".join(model)`` to its score (0 in ``"conf"``
            mode when nothing was predicted).
        """
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        results = {}
        if mode == "auto":
            with h5py.File(self.cachefile, 'r', libver='latest',
                           swmr=True) as cache:
                # any() stops at the first line that carries a ground truth text
                has_gt = any("text" in cache[b][p][s].attrs
                             for b in books
                             for p in cache[b]
                             for s in cache[b][p])
            mode = "eval" if has_gt else "conf"

        if mode == "conf":
            dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
        else:
            dset = Nash5DataSet(DataSetMode.TRAIN, self.cachefile, books)
            dset.mode = DataSetMode.PREDICT  # otherwise results are randomised

        if 0 < sample < len(dset):
            delsamples = random.sample(dset._samples, len(dset) - sample)
            for s in delsamples:
                dset._samples.remove(s)

        # both modes use the same model-independent voter; build it only once
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(
            "confidence_voter_default_ctc".upper())
        voter = voter_from_proto(voter_params)

        if mode == "conf":
            for model in models:
                if isinstance(model, str):
                    model = [model]
                predictor = MultiPredictor(checkpoints=model,
                                           data_preproc=NoopDataPreprocessor(),
                                           batch_size=1,
                                           processes=1)
                do_prediction = predictor.predict_dataset(dset,
                                                          progress_bar=True)
                avg_sentence_confidence = 0
                n_predictions = 0
                for result, _ in do_prediction:
                    n_predictions += 1
                    prediction = voter.vote_prediction_result(result)
                    avg_sentence_confidence += prediction.avg_char_probability

                # guard against ZeroDivisionError on an empty dataset
                results["/".join(model)] = (
                    avg_sentence_confidence / n_predictions
                    if n_predictions else 0)

        else:
            preproc = self.bidi_preproc if rtl else self.txt_preproc
            for model in models:
                if isinstance(model, str):
                    model = [model]

                predictor = MultiPredictor(checkpoints=model,
                                           data_preproc=NoopDataPreprocessor(),
                                           batch_size=1,
                                           processes=1)

                out_gen = predictor.predict_dataset(dset, progress_bar=True)

                pred_dset = RawDataSet(DataSetMode.EVAL,
                                       texts=preproc.apply([
                                           voter.vote_prediction_result(
                                               d[0]).sentence for d in out_gen
                                       ]))

                evaluator = Evaluator(text_preprocessor=NoopTextProcessor(),
                                      skip_empty_gt=False)
                r = evaluator.run(gt_dataset=dset,
                                  pred_dataset=pred_dset,
                                  processes=1,
                                  progress_bar=True)

                results["/".join(model)] = 1 - r["avg_ler"]
        return results
예제 #9
0
def run(args):
    """Predict all input lines with the given checkpoints and store the results.

    The predictions of all checkpoints are combined with the voter named by
    ``args.voter``. The voted sentence is stored via ``dataset.store_text`` with
    extension ``.pred.txt``. If ``args.extended_prediction_data`` is set, a
    ``Predictions`` message with the voted and all per-model predictions is also
    written per line, as ``<id>.pred`` (protobuf) or ``<id>.json`` (logits
    cleared).

    :param args: parsed command line namespace. If ``args.files`` is a single
        path ending in ``json``, that file is loaded and its keys override the
        attributes of ``args``.
    :raises Exception: on an unsupported extended prediction data format or an
        empty dataset.
    """

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them, there wont be predictions of invalid files
    dataset = create_dataset(
        args.dataset,
        DataSetMode.PREDICT,
        input_image_files,
        args.text_files,
        skip_invalid=True,
        remove_invalid=True,
        args={'text_index': args.pagexml_text_index},
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(
        dataset, progress_bar=not args.no_progress_bars)

    avg_sentence_confidence = 0
    n_predictions = 0

    # output the voted results to the appropriate files
    for result, sample in do_prediction:
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # open with LRE (U+202A) or RLE (U+202B), chosen by get_base_level,
            # and close with PDF (U+202C)
            lr = "\u202A\u202B"
            print("{}: '{}{}{}'".format(sample['id'],
                                        lr[get_base_level(sentence)], sentence,
                                        "\u202C"))

        output_dir = args.output_dir

        dataset.store_text(sentence,
                           sample,
                           output_dir=output_dir,
                           extension=".pred.txt")

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample[
                'image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(
                ps.line_path)
            # dirname() of a bare filename is "" (the current directory), which
            # must not be created; makedirs also creates missing parent dirs
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"),
                          'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"),
                          'w') as f:
                    # remove logits
                    for pred in ps.predictions:
                        pred.logits.rows = 0
                        pred.logits.cols = 0
                        pred.logits.data[:] = []

                    f.write(
                        MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # guard against ZeroDivisionError if the predictor yielded nothing
    avg_conf = avg_sentence_confidence / n_predictions if n_predictions else 0
    print("Average sentence confidence: {:.2%}".format(avg_conf))

    dataset.store()
    print("All files written")
예제 #10
0
def main():
    """Command line entry point: predict line images and write ``.pred.txt`` files.

    Every image matched by ``--files`` is predicted with all ``--checkpoint``
    models, and the results are combined with the ``--voter`` algorithm. The
    voted sentence is written to ``<id>.pred.txt`` in ``--output_dir``, or next
    to the image if no output dir is given. Optionally an extended ``.pred`` or
    ``.json`` file is written as well.

    :raises Exception: on an unsupported extended prediction data format or an
        empty dataset.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--files",
                        nargs="+",
                        required=True,
                        default=[],
                        help="List all image files that shall be processed")
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help=
        "The batch size during the prediction (number of lines to process in parallel)"
    )
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument(
        "--voter",
        type=str,
        default="confidence_voter_default_ctc",
        help=
        "The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
        "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument(
        "--output_dir",
        type=str,
        help=
        "By default the prediction files will be written to the same directory as the given files. "
        "You can use this argument to specify a specific output dir for the prediction files."
    )
    parser.add_argument(
        "--extended_prediction_data",
        action="store_true",
        help=
        "Write: Predicted string, labels; position, probabilities and alternatives of chars to a .pred (protobuf) file"
    )
    parser.add_argument(
        "--extended_prediction_data_format",
        type=str,
        default="json",
        help=
        "Extension format: Either pred or json. Note that json will not print logits."
    )
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")

    args = parser.parse_args()

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = sorted(glob_all(args.files))

    # skip invalid files, but keep then so that empty predictions are created
    dataset = FileDataSet(input_image_files,
                          skip_invalid=True,
                          remove_invalid=False)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint)
    do_prediction = predictor.predict_dataset(
        dataset,
        batch_size=args.batch_size,
        processes=args.processes,
        progress_bar=not args.no_progress_bars)

    # output the voted results to the appropriate files
    for (result, sample), filepath in zip(do_prediction, input_image_files):
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        if args.verbose:
            print("{}: '{}'".format(sample['id'], sentence))

        output_dir = args.output_dir if args.output_dir else os.path.dirname(
            filepath)
        # a user-supplied --output_dir may not exist yet; "" means the current dir
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with codecs.open(os.path.join(output_dir, sample['id'] + ".pred.txt"),
                         'w', 'utf-8') as f:
            f.write(sentence)

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = filepath
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"),
                          'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"),
                          'w') as f:
                    # remove logits
                    for pred in ps.predictions:
                        pred.logits.rows = 0
                        pred.logits.cols = 0
                        pred.logits.data[:] = []

                    f.write(
                        MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    print("All files written")
예제 #11
0
def run(args):
    """Predict the lines of ABBYY documents and write the voted texts back as XML.

    The predictions of all ``args.checkpoint`` models are combined with the voter
    named by ``args.voter``. Each voted sentence is written into the sample's
    ``format`` object, and the whole book is then serialized with ``XMLWriter``.
    Optionally, extended ``.pred`` or ``.json`` files are written per line.

    :param args: parsed command line namespace.
    :raises Exception: on an unsupported extended prediction data format, an
        empty dataset, or when no prediction was produced.
    """
    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    files = glob.glob(args.files)
    dataset = AbbyyDataSet(files,
                           skip_invalid=True,
                           remove_invalid=False,
                           binary=args.binary)

    dataset.load_samples(processes=args.processes,
                         progress_bar=not args.no_progress_bars)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(
        dataset, progress_bar=not args.no_progress_bars)

    # one page image entry per format of each page; presumably the dataset
    # yields one sample per format, so the zip below pairs each prediction with
    # its page image -- TODO confirm against AbbyyDataSet
    input_image_files = [
        page.imgFile for page in dataset.book.pages
        for _ in page.getFormats()
    ]

    # output the voted results to the appropriate files
    filepath = None
    output_dir = args.output_dir
    for (result, sample), filepath in zip(do_prediction, input_image_files):
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        if args.verbose:
            # open with LRE (U+202A) or RLE (U+202B), chosen by get_base_level,
            # and close with PDF (U+202C)
            lr = "\u202A\u202B"
            print("{}: '{}{}{}'".format(sample['id'],
                                        lr[get_base_level(sentence)], sentence,
                                        "\u202C"))

        output_dir = args.output_dir if args.output_dir else os.path.dirname(
            filepath)

        sample["format"].text = sentence

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = filepath
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            # a user-supplied output dir may not exist yet; "" means the current dir
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"),
                          'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"),
                          'w') as f:
                    # remove logits
                    for pred in ps.predictions:
                        pred.logits.rows = 0
                        pred.logits.cols = 0
                        pred.logits.data[:] = []

                    f.write(
                        MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # The XML writer uses the directories of the last processed line. Without any
    # prediction they are undefined, so fail with a clear message instead.
    if filepath is None:
        raise Exception("No predictions were made, nothing to write.")

    w = XMLWriter(output_dir, os.path.dirname(filepath), dataset.book)
    w.write()

    print("All files written")
예제 #12
0
파일: predict.py 프로젝트: AIRob/calamari
def run(args):
    """Predict all input lines with the given checkpoints and store the results.

    The predictions of all checkpoints are combined with the voter named by
    ``args.voter``. The voted sentence is stored via ``dataset.store_text`` with
    extension ``.pred.txt``. If ``args.extended_prediction_data`` is set, a
    ``Predictions`` message with the voted and all per-model predictions is also
    written per line, as ``<id>.pred`` (protobuf) or ``<id>.json`` (logits
    cleared).

    :param args: parsed command line namespace. If ``args.files`` is a single
        path ending in ``json``, that file is loaded and its keys override the
        attributes of ``args``.
    :raises Exception: on an unsupported extended prediction data format or an
        empty dataset.
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception("Only 'pred' and 'json' are allowed extended prediction data formats")

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json") for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them, there wont be predictions of invalid files
    dataset = create_dataset(
        args.dataset,
        DataSetMode.PREDICT,
        input_image_files,
        args.text_files,
        skip_invalid=True,
        remove_invalid=True,
        args={'text_index': args.pagexml_text_index},
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception("Empty dataset provided. Check your files argument (got {})!".format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=not args.no_progress_bars)

    avg_sentence_confidence = 0
    n_predictions = 0

    # output the voted results to the appropriate files
    for result, sample in do_prediction:
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # open with LRE (U+202A) or RLE (U+202B), chosen by get_base_level,
            # and close with PDF (U+202C)
            lr = "\u202A\u202B"
            print("{}: '{}{}{}'".format(sample['id'], lr[get_base_level(sentence)], sentence, "\u202C" ))

        output_dir = args.output_dir

        dataset.store_text(sentence, sample, output_dir=output_dir, extension=".pred.txt")

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample['image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] + [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(ps.line_path)
            # dirname() of a bare filename is "" (the current directory), which
            # must not be created; makedirs also creates missing parent dirs
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"), 'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"), 'w') as f:
                    # remove logits
                    for pred in ps.predictions:
                        pred.logits.rows = 0
                        pred.logits.cols = 0
                        pred.logits.data[:] = []

                    f.write(MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # guard against ZeroDivisionError if the predictor yielded nothing
    avg_conf = avg_sentence_confidence / n_predictions if n_predictions else 0
    print("Average sentence confidence: {:.2%}".format(avg_conf))

    dataset.store()
    print("All files written")