Esempio n. 1
0
 def test_raw_prediction_voted(self):
     """Predict raw image arrays with the multi-model predictor and print the texts.

     All files named by the prediction attributes are loaded as uint8 arrays
     first. Each image is then predicted on its own, and the ``sentence`` of
     every entry of the first yielded result is printed next to the file name.
     """
     attrs = PredictionAttrs()
     multi_predictor = MultiPredictor(checkpoints=attrs.checkpoint)
     # Load every image before predicting, so an unreadable file fails early.
     loaded = [np.array(Image.open(path), dtype=np.uint8) for path in attrs.files]
     for path, pixels in zip(attrs.files, loaded):
         per_sample = list(multi_predictor.predict_raw([pixels], progress_bar=False))
         # Only one image was passed in, so only the first result is of interest.
         first_result = per_sample[0]
         print(path, [entry.sentence for entry in first_result])
Esempio n. 2
0
    def predict_books(self, books, models, pageupload=False, text_index=1):
        """Predict text for the given books and store the voted results.

        A prediction dataset is built from ``self.cachefile`` and *books*.
        Every sample is predicted by all *models*, the per-model results are
        merged by the ``confidence_voter_default_ctc`` voter, and the voted
        sentence is handed to the dataset for storage.

        :param books: a single book name or a list of book names.
        :param models: a single checkpoint path or a list of checkpoint paths.
        :param pageupload: accepted but not used by this implementation.
        :param text_index: accepted but not used by this implementation.
        """
        # Accept a bare string as shorthand for a one-element list.
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)

        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
        voter = voter_from_proto(voter_params)

        # predict for all models
        predictor = MultiPredictor(checkpoints=models, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1)
        do_prediction = predictor.predict_dataset(dset, progress_bar=True)

        avg_sentence_confidence = 0
        n_predictions = 0
        # output the voted results to the appropriate files
        for result, sample in do_prediction:
            n_predictions += 1
            for i, p in enumerate(result):
                p.prediction.id = "fold_{}".format(i)

            # vote the results (if only one model is given, this will just return the sentences)
            prediction = voter.vote_prediction_result(result)
            prediction.id = "voted"
            avg_sentence_confidence += prediction.avg_char_probability

            dset.store_text(prediction.sentence, sample, output_dir=None, extension=".pred.txt")

        # An empty dataset yields no predictions; report 0 instead of dividing by zero.
        avg_conf = avg_sentence_confidence / n_predictions if n_predictions else 0
        print("Average sentence confidence: {:.2%}".format(avg_conf))

        dset.store()
        print("All files written")
Esempio n. 3
0
	def _load_models(self):
		"""Create the predictor (and voter, if needed) on first use.

		Does nothing when a predictor already exists or when the OCR
		backend is the "FAKE" placeholder. With exactly one model a plain
		``Predictor`` is used without a voter; otherwise a ``MultiPredictor``
		is combined with a ``ConfidenceVoter``. The line height is read from
		the (first) predictor's model parameters.
		"""
		if self._predictor is not None or self._ocr == "FAKE":
			return

		batch_size = self._options["batch_size"]
		# Only pass an explicit batch size on when it is positive.
		batch_size_kwargs = dict(batch_size=batch_size) if batch_size > 0 else dict()
		self._chunk_size = batch_size

		model_count = len(self._models)
		if model_count != 1:
			logging.info("using Calamari voting with %d models." % model_count)
			checkpoint_paths = [str(p) for p in self._models]
			self._predictor = MultiPredictor(
				checkpoints=checkpoint_paths, **batch_size_kwargs)
			self._predict_kwargs = dict()
			self._voter = ConfidenceVoter()
			first_predictor = self._predictor.predictors[0]
			self._line_height = int(first_predictor.model_params.line_height)
			return

		single_path = str(self._models[0])
		self._predictor = Predictor(single_path, **batch_size_kwargs)
		self._predict_kwargs = batch_size_kwargs
		self._voter = None
		self._line_height = int(self._predictor.model_params.line_height)
Esempio n. 4
0
    def _init_calamari(self):
        """Create the multi-model predictor and the voter from the parameters.

        The ``checkpoint`` parameter is expanded as a glob pattern, and the
        ``voter`` parameter names the voter type (matched in upper case).
        """
        # Set the TensorFlow log level before the predictor is created.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = TF_CPP_MIN_LOG_LEVEL

        checkpoint_paths = glob(self.parameter['checkpoint'])
        self.predictor = MultiPredictor(checkpoints=checkpoint_paths)

        params = VoterParams()
        voter_name = self.parameter['voter'].upper()
        params.type = VoterParams.Type.Value(voter_name)
        self.voter = voter_from_proto(params)
Esempio n. 5
0
    def predict_books(self, books, models, pageupload=True, text_index=1):
        """Predict text for the given books and upload or store the voted results.

        A prediction dataset is built from ``self.cachefile`` and *books*.
        Every sample is predicted by all *models*, the per-model results are
        merged by the ``confidence_voter_default_ctc`` voter, and the voted
        sentence is handed to the dataset.

        :param books: a single book name or a list of book names.
        :param models: a single checkpoint path or a list of checkpoint paths.
        :param pageupload: if true, POST the predictions (gzip-compressed JSON)
            to ``<baseurl>/_ocrdata``; otherwise store them via the dataset.
        :param text_index: sent as ``index`` together with the uploaded data.
        """
        if not pageupload:
            print("""Warning: trying to save results to the hdf5-Cache may fail due to some issue
                  with file access from multiple threads. It should work, however, if you set
                  export HDF5_USE_FILE_LOCKING='FALSE'.""")
        # Accept a bare string as shorthand for a one-element list.
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)

        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
        voter = voter_from_proto(voter_params)

        # predict for all models
        predictor = MultiPredictor(checkpoints=models, data_preproc=PadNoopDataPreprocessor(), batch_size=1, processes=1)
        do_prediction = predictor.predict_dataset(dset, progress_bar=True)

        avg_sentence_confidence = 0
        n_predictions = 0
        # output the voted results to the appropriate files
        for result, sample in do_prediction:
            n_predictions += 1
            for i, p in enumerate(result):
                p.prediction.id = "fold_{}".format(i)

            # vote the results (if only one model is given, this will just return the sentences)
            prediction = voter.vote_prediction_result(result)
            prediction.id = "voted"
            avg_sentence_confidence += prediction.avg_char_probability

            dset.store_text(prediction.sentence, sample, output_dir=None, extension=".pred.txt")
        # No predictions means there is nothing to average; report 0.
        avg_conf = avg_sentence_confidence / n_predictions if n_predictions else 0
        print("Average sentence confidence: {:.2%}".format(avg_conf))

        if pageupload:
            # Regroup the flat "<prefix>/<book>/<page>/<line>" keys into
            # nested book -> page -> line dictionaries.
            ocrdata = {}
            for lname, text in dset.predictions.items():
                _, book, page, line = lname.split("/")
                ocrdata.setdefault(book, {}).setdefault(page, {})[line] = text

            data = {"ocrdata": ocrdata, "index": text_index}
            self.session.post(self.baseurl+"/_ocrdata",
                              data=gzip.compress(json.dumps(data).encode("utf-8")),
                              headers={"Content-Type": "application/json;charset=UTF-8",
                                       "Content-Encoding": "gzip"})
            print("Results uploaded")
        else:
            dset.store()
            print("All files written")
Esempio n. 6
0
    def setup(self):
        """
        Load all checkpoints of the configured directory and create the voter.
        """
        checkpoint_dir = self.resolve_resource(self.parameter['checkpoint_dir'])
        checkpoint_files = glob('%s/*.ckpt.json' % checkpoint_dir)
        self.predictor = MultiPredictor(checkpoints=checkpoint_files)

        # The input channel count is taken from the first model of the ensemble.
        first_predictor = self.predictor.predictors[0]
        self.network_input_channels = first_predictor.network.input_channels
        # Deriving a feature selector ('binarized' / 'grayscale_normalized')
        # from the model parameters is currently disabled, so no image
        # features are requested.
        self.features = ''

        params = VoterParams()
        voter_name = self.parameter['voter'].upper()
        params.type = VoterParams.Type.Value(voter_name)
        self.voter = voter_from_proto(params)
Esempio n. 7
0
    def evaluate_books(self, books, models, mode="auto", sample=-1):
        """Score each model (or model ensemble) on the given books.

        :param books: a single book name or a list of book names.
        :param models: a single checkpoint path or a list; each entry is
            either one checkpoint path or a list of paths forming an ensemble.
        :param mode: ``"conf"`` scores by average voted character confidence,
            any other value by ``1 - avg_ler`` against the ground truth.
            ``"auto"`` picks ``"eval"`` if any line of the books has a
            ``text`` attribute in the cache file, else ``"conf"``.
        :param sample: if between 0 and the dataset size (exclusive), the
            dataset is reduced to that many randomly chosen samples.
        :return: dict mapping ``"/".join(checkpoint paths)`` to the score.
        """
        # Accept a bare string as shorthand for a one-element list.
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        results = {}
        if mode == "auto":
            with h5py.File(self.cachefile, 'r', libver='latest', swmr=True) as cache:
                # any() stops at the first line that carries ground truth text.
                has_text = any("text" in cache[b][p][s].attrs
                               for b in books
                               for p in cache[b]
                               for s in cache[b][p])
            mode = "eval" if has_text else "conf"

        if mode == "conf":
            dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
        else:
            dset = Nash5DataSet(DataSetMode.EVAL, self.cachefile, books)

        if 0 < sample < len(dset):
            # Randomly drop samples until only `sample` of them remain.
            delsamples = random.sample(dset._samples, len(dset) - sample)
            for s in delsamples:
                dset._samples.remove(s)

        if mode == "conf":
            for model in models:
                if isinstance(model, str):
                    model = [model]
                predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1)
                voter_params = VoterParams()
                voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
                voter = voter_from_proto(voter_params)
                do_prediction = predictor.predict_dataset(dset, progress_bar=True)
                avg_sentence_confidence = 0
                n_predictions = 0
                # The yielded sample is not needed here; the name is kept
                # distinct from the `sample` parameter.
                for result, _sample in do_prediction:
                    n_predictions += 1
                    prediction = voter.vote_prediction_result(result)
                    avg_sentence_confidence += prediction.avg_char_probability
                # No predictions means there is nothing to average; report 0.
                results["/".join(model)] = (
                    avg_sentence_confidence / n_predictions if n_predictions else 0)

        else:
            for model in models:
                if isinstance(model, str):
                    model = [model]
                # NOTE(review): this call passes `checkpoint=` while the "conf"
                # branch passes `checkpoints=` -- confirm which keyword the
                # MultiPredictor in use actually accepts.
                predictor = MultiPredictor(checkpoint=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1, with_gt=True)
                out_gen = predictor.predict_dataset(dset, progress_bar=True, apply_preproc=False)
                # Pair ground truth and predicted characters of the first
                # result of every sample.
                gt_pred_pairs = ((''.join(d[0].ground_truth), ''.join(d[0].chars)) for d in out_gen)
                result = Evaluator.evaluate_single_list(
                    map(Evaluator.evaluate_single_args, gt_pred_pairs))
                results["/".join(model)] = 1 - result["avg_ler"]
        return results
Esempio n. 8
0
class CalamariRecognize(Processor):
    """OCR-D processor that recognizes text lines with an ensemble of Calamari models."""

    def __init__(self, *args, **kwargs):
        """Register the tool description and version; load models in a processing context."""
        kwargs['ocrd_tool'] = OCRD_TOOL['tools'][TOOL]
        kwargs['version'] = '%s (calamari %s, tensorflow %s)' % (
            OCRD_TOOL['version'], calamari_version, tensorflow_version)
        super(CalamariRecognize, self).__init__(*args, **kwargs)
        if hasattr(self, 'output_file_grp'):
            # processing context
            self.setup()

    def setup(self):
        """
        Set up the model prior to processing.
        """
        resolved = self.resolve_resource(self.parameter['checkpoint_dir'])
        checkpoints = glob('%s/*.ckpt.json' % resolved)
        self.predictor = MultiPredictor(checkpoints=checkpoints)

        self.network_input_channels = self.predictor.predictors[
            0].network.input_channels
        #self.network_input_channels = self.predictor.predictors[0].network_params.channels # not used!
        # binarization = self.predictor.predictors[0].model_params.data_preprocessor.binarization # not used!
        # self.features = ('' if self.network_input_channels != 1 else
        #                  'binarized' if binarization != 'GRAY' else
        #                  'grayscale_normalized')
        self.features = ''

        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(
            self.parameter['voter'].upper())
        self.voter = voter_from_proto(voter_params)

    def process(self):
        """
        Perform text recognition with Calamari on the workspace.

        If ``textequiv_level`` is ``word`` or ``glyph``, then additionally create word / glyph level segments by
        splitting at white space characters / glyph boundaries. In the case of ``glyph``, add all alternative character
        hypotheses down to ``glyph_conf_cutoff`` confidence threshold.
        """
        log = getLogger('processor.CalamariRecognize')

        assert_file_grp_cardinality(self.input_file_grp, 1)
        assert_file_grp_cardinality(self.output_file_grp, 1)

        # The helpers below do not depend on any per-line state, so they are
        # defined once here instead of once per recognized line.

        def _sort_chars(p):
            """Filter and sort chars of prediction p"""
            chars = p.chars
            chars = [
                c for c in chars if c.char
            ]  # XXX Note that omission probabilities are not normalized?!
            chars = [
                c for c in chars if c.probability >=
                self.parameter['glyph_conf_cutoff']
            ]
            chars = sorted(chars,
                           key=lambda k: k.probability,
                           reverse=True)
            return chars

        def _drop_leading_spaces(positions):
            """Remove positions from the front whose best char is a space."""
            return list(
                itertools.dropwhile(
                    lambda p: _sort_chars(p)[0].char == " ",
                    positions))

        def _drop_trailing_spaces(positions):
            """Remove positions from the end whose best char is a space."""
            return list(
                reversed(_drop_leading_spaces(
                    reversed(positions))))

        def _drop_double_spaces(positions):
            """Collapse every run of space positions into its first position."""
            def _drop_double_spaces_generator(positions):
                last_was_space = False
                for p in positions:
                    # NOTE(review): this looks at p.chars[0] whereas the other
                    # helpers use _sort_chars(p)[0] -- confirm they always agree.
                    if p.chars[0].char == " ":
                        if not last_was_space:
                            yield p
                        last_was_space = True
                    else:
                        yield p
                        last_was_space = False

            return list(_drop_double_spaces_generator(positions))

        def _words(s):
            """Split words based on spaces and include spaces as 'words'"""
            spaces = None
            word = ''
            for c in s:
                if c == ' ' and spaces is True:
                    word += c
                elif c != ' ' and spaces is False:
                    word += c
                else:
                    if word:
                        yield word
                    word = c
                    spaces = (c == ' ')
            yield word

        for (n, input_file) in enumerate(self.input_files):
            page_id = input_file.pageId or input_file.ID
            log.info("INPUT FILE %i / %s", n, page_id)
            pcgts = page_from_file(self.workspace.download_file(input_file))

            page = pcgts.get_Page()
            page_image, page_coords, page_image_info = self.workspace.image_from_page(
                page, page_id, feature_selector=self.features)

            for region in page.get_AllRegions(classes=['Text']):
                region_image, region_coords = self.workspace.image_from_segment(
                    region,
                    page_image,
                    page_coords,
                    feature_selector=self.features)

                textlines = region.get_TextLine()
                log.info("About to recognize %i lines of region '%s'",
                         len(textlines), region.id)
                line_images_np = []
                line_coordss = []
                line_heights = []
                for line in textlines:
                    log.debug("Recognizing line '%s' in region '%s'", line.id,
                              region.id)

                    line_image, line_coords = self.workspace.image_from_segment(
                        line,
                        region_image,
                        region_coords,
                        feature_selector=self.features)
                    if ('binarized' not in line_coords['features']
                            and 'grayscale_normalized'
                            not in line_coords['features']
                            and self.network_input_channels == 1):
                        # We cannot use a feature selector for this since we don't
                        # know whether the model expects (has been trained on)
                        # binarized or grayscale images; but raw images are likely
                        # always inadequate:
                        log.warning(
                            "Using raw image for line '%s' in region '%s'",
                            line.id, region.id)

                    # Replace an image with a zero-sized dimension by a single
                    # black pixel so that the array is never empty.
                    line_image = line_image if all(line_image.size) else [[0]]
                    line_image_np = np.array(line_image, dtype=np.uint8)
                    line_images_np.append(line_image_np)
                    line_coordss.append(line_coords)
                    # Remember this line's own height (first array axis) for
                    # the word / glyph polygons built below; `line_image`
                    # itself is overwritten by the next iteration.
                    line_heights.append(line_image_np.shape[0])
                raw_results_all = self.predictor.predict_raw(
                    line_images_np, progress_bar=False)

                for line, line_coords, line_height, raw_results in zip(
                        textlines, line_coordss, line_heights,
                        raw_results_all):

                    for fold_no, p in enumerate(raw_results):
                        p.prediction.id = "fold_{}".format(fold_no)

                    prediction = self.voter.vote_prediction_result(raw_results)
                    prediction.id = "voted"

                    # Build line text on our own
                    #
                    # Calamari does whitespace post-processing on prediction.sentence, while it does not do the same
                    # on prediction.positions. Do it on our own to have consistency.
                    #
                    # XXX Check Calamari's built-in post-processing on prediction.sentence
                    positions = prediction.positions
                    positions = _drop_leading_spaces(positions)
                    positions = _drop_trailing_spaces(positions)
                    positions = _drop_double_spaces(positions)
                    positions = list(positions)

                    line_text = ''.join(
                        _sort_chars(p)[0].char for p in positions)
                    if line_text != prediction.sentence:
                        log.warning(
                            "Our own line text is not the same as Calamari's: '%s' != '%s'",
                            line_text, prediction.sentence)

                    # Delete existing results
                    if line.get_TextEquiv():
                        log.warning("Line '%s' already contained text results",
                                    line.id)
                    line.set_TextEquiv([])
                    if line.get_Word():
                        log.warning(
                            "Line '%s' already contained word segmentation",
                            line.id)
                    line.set_Word([])

                    # Save line results
                    line_conf = prediction.avg_char_probability
                    line.set_TextEquiv(
                        [TextEquivType(Unicode=line_text, conf=line_conf)])

                    # Save word results
                    #
                    # Calamari OCR does not provide word positions, so we infer word positions from a. text segmentation
                    # and b. the glyph positions. This is necessary because the PAGE XML format enforces a strict
                    # hierarchy of lines > words > glyphs.
                    if self.parameter['textequiv_level'] in ['word', 'glyph']:
                        word_no = 0
                        # Offset of the current word's first character in `positions`.
                        i = 0

                        for word_text in _words(line_text):
                            word_length = len(word_text)
                            if not all(c == ' ' for c in word_text):
                                word_positions = positions[i:i + word_length]
                                word_start = word_positions[0].global_start
                                word_end = word_positions[-1].global_end

                                polygon = polygon_from_x0y0x1y1([
                                    word_start, 0, word_end, line_height
                                ])
                                points = points_from_polygon(
                                    coordinates_for_segment(
                                        polygon, None, line_coords))
                                # XXX Crop to line polygon?

                                word = WordType(id='%s_word%04d' %
                                                (line.id, word_no),
                                                Coords=CoordsType(points))
                                word.add_TextEquiv(
                                    TextEquivType(Unicode=word_text))

                                if self.parameter[
                                        'textequiv_level'] == 'glyph':
                                    for glyph_no, p in enumerate(
                                            word_positions):
                                        glyph_start = p.global_start
                                        glyph_end = p.global_end

                                        polygon = polygon_from_x0y0x1y1([
                                            glyph_start, 0, glyph_end,
                                            line_height
                                        ])
                                        points = points_from_polygon(
                                            coordinates_for_segment(
                                                polygon, None, line_coords))

                                        glyph = GlyphType(
                                            id='%s_glyph%04d' %
                                            (word.id, glyph_no),
                                            Coords=CoordsType(points))

                                        # Add predictions (= TextEquivs)
                                        char_index_start = 1  # Must start with 1, see https://ocr-d.github.io/page#multiple-textequivs
                                        for char_index, char in enumerate(
                                                _sort_chars(p),
                                                start=char_index_start):
                                            glyph.add_TextEquiv(
                                                TextEquivType(
                                                    Unicode=char.char,
                                                    index=char_index,
                                                    conf=char.probability))

                                        word.add_Glyph(glyph)

                                line.add_Word(word)
                                word_no += 1

                            i += word_length

            _page_update_higher_textequiv_levels('line', pcgts)

            # Add metadata about this operation and its runtime parameters:
            self.add_metadata(pcgts)
            file_id = make_file_id(input_file, self.output_file_grp)
            pcgts.set_pcGtsId(file_id)
            self.workspace.add_file(ID=file_id,
                                    file_grp=self.output_file_grp,
                                    pageId=input_file.pageId,
                                    mimetype=MIMETYPE_PAGE,
                                    local_filename=os.path.join(
                                        self.output_file_grp,
                                        file_id + '.xml'),
                                    content=to_xml(pcgts))
Esempio n. 9
0
    def evaluate_books(self, books, models, rtl=False, mode="auto", sample=-1):
        """Score each model (or model ensemble) on the given books.

        :param books: a single book name or a list of book names.
        :param models: a single checkpoint path or a list; each entry is
            either one checkpoint path or a list of paths forming an ensemble.
        :param rtl: if true, the voted sentences are passed through
            ``self.bidi_preproc`` before evaluation, else ``self.txt_preproc``.
        :param mode: ``"conf"`` scores by average voted character confidence,
            any other value by ``1 - avg_ler`` against the ground truth.
            ``"auto"`` picks ``"eval"`` if any line of the books has a
            ``text`` attribute in the cache file, else ``"conf"``.
        :param sample: if between 0 and the dataset size (exclusive), the
            dataset is reduced to that many randomly chosen samples.
        :return: dict mapping ``"/".join(checkpoint paths)`` to the score.
        """
        # Accept a bare string as shorthand for a one-element list.
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        results = {}
        if mode == "auto":
            with h5py.File(self.cachefile, 'r', libver='latest',
                           swmr=True) as cache:
                # any() stops at the first line that carries ground truth text.
                has_text = any("text" in cache[b][p][s].attrs
                               for b in books
                               for p in cache[b]
                               for s in cache[b][p])
            mode = "eval" if has_text else "conf"

        if mode == "conf":
            dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
        else:
            dset = Nash5DataSet(DataSetMode.TRAIN, self.cachefile, books)
            dset.mode = DataSetMode.PREDICT  # otherwise results are randomised

        if 0 < sample < len(dset):
            # Randomly drop samples until only `sample` of them remain.
            delsamples = random.sample(dset._samples, len(dset) - sample)
            for s in delsamples:
                dset._samples.remove(s)

        if mode == "conf":
            for model in models:
                if isinstance(model, str):
                    model = [model]
                predictor = MultiPredictor(checkpoints=model,
                                           data_preproc=NoopDataPreprocessor(),
                                           batch_size=1,
                                           processes=1)
                voter_params = VoterParams()
                voter_params.type = VoterParams.Type.Value(
                    "confidence_voter_default_ctc".upper())
                voter = voter_from_proto(voter_params)
                do_prediction = predictor.predict_dataset(dset,
                                                          progress_bar=True)
                avg_sentence_confidence = 0
                n_predictions = 0
                # The yielded sample is not needed here; the name is kept
                # distinct from the `sample` parameter.
                for result, _sample in do_prediction:
                    n_predictions += 1
                    prediction = voter.vote_prediction_result(result)
                    avg_sentence_confidence += prediction.avg_char_probability

                # No predictions means there is nothing to average; report 0.
                results["/".join(model)] = (
                    avg_sentence_confidence / n_predictions
                    if n_predictions else 0)

        else:
            for model in models:
                if isinstance(model, str):
                    model = [model]

                predictor = MultiPredictor(checkpoints=model,
                                           data_preproc=NoopDataPreprocessor(),
                                           batch_size=1,
                                           processes=1)

                voter_params = VoterParams()
                voter_params.type = VoterParams.Type.Value(
                    "confidence_voter_default_ctc".upper())
                voter = voter_from_proto(voter_params)

                out_gen = predictor.predict_dataset(dset, progress_bar=True)

                preproc = self.bidi_preproc if rtl else self.txt_preproc

                # Vote every sample's per-model results and wrap the
                # preprocessed sentences in a dataset for the evaluator.
                voted_sentences = [
                    voter.vote_prediction_result(d[0]).sentence
                    for d in out_gen
                ]
                pred_dset = RawDataSet(DataSetMode.EVAL,
                                       texts=preproc.apply(voted_sentences))

                evaluator = Evaluator(text_preprocessor=NoopTextProcessor(),
                                      skip_empty_gt=False)
                r = evaluator.run(gt_dataset=dset,
                                  pred_dataset=pred_dset,
                                  processes=1,
                                  progress_bar=True)

                results["/".join(model)] = 1 - r["avg_ler"]
        return results
Esempio n. 10
0
def run(args):
    """Predict all input files with the given checkpoints and write the voted text.

    *args* is an argparse-style namespace. If ``args.files`` is a single path
    ending in ``json``, that file is loaded and each of its keys is set as an
    attribute on *args* first. Every sample is predicted by all checkpoints,
    the results are voted, and the voted sentence is stored through the
    dataset with the ``.pred.txt`` extension. With
    ``args.extended_prediction_data`` the voted and per-model predictions are
    also written per sample, as serialized protobuf (``pred``) or JSON
    (``json``).

    :raises Exception: on an unknown extended prediction data format or when
        the dataset is empty.
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them, there wont be predictions of invalid files
    dataset = create_dataset(
        args.dataset,
        DataSetMode.PREDICT,
        input_image_files,
        args.text_files,
        skip_invalid=True,
        remove_invalid=True,
        args={'text_index': args.pagexml_text_index},
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(
        dataset, progress_bar=not args.no_progress_bars)

    avg_sentence_confidence = 0
    n_predictions = 0

    # output the voted results to the appropriate files
    for result, sample in do_prediction:
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # One of two directional embedding marks is chosen by the
            # sentence's base level; "\u202C" closes the embedding again.
            lr = "\u202A\u202B"
            print("{}: '{}{}{}'".format(sample['id'],
                                        lr[get_base_level(sentence)], sentence,
                                        "\u202C"))

        output_dir = args.output_dir

        dataset.store_text(sentence,
                           sample,
                           output_dir=output_dir,
                           extension=".pred.txt")

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample[
                'image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            # Without an explicit output directory, write next to the line.
            extended_dir = output_dir if output_dir else os.path.dirname(
                ps.line_path)
            # An empty name means the current directory, which needs no
            # creation; makedirs also copes with nested and existing paths.
            if extended_dir:
                os.makedirs(extended_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(extended_dir, sample['id'] + ".pred"),
                          'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(extended_dir, sample['id'] + ".json"),
                          'w') as f:
                    # remove logits
                    for stored in ps.predictions:
                        stored.logits.rows = 0
                        stored.logits.cols = 0
                        stored.logits.data[:] = []

                    f.write(
                        MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # Report 0 rather than dividing by zero if nothing was predicted.
    avg_conf = avg_sentence_confidence / n_predictions if n_predictions else 0
    print("Average sentence confidence: {:.2%}".format(avg_conf))

    dataset.store()
    print("All files written")
Esempio n. 11
0
def main():
    """Evaluate several checkpoints and voters against ground truth texts.

    Parses the command line, predicts every sample of the evaluation dataset
    with all given checkpoints, votes the per-model predictions once per
    requested voter, evaluates each individual model and each voter, and
    optionally pickles the collected results to the ``--dump`` path.

    Raises:
        Exception: If the dataset is empty, or if the number of predicted
            sentences does not match the number of dataset samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+", default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
                             "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # load files; the ground truth path of each image is derived from the
    # already resolved image list so both lists are guaranteed to be aligned
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # The parser defines --eval_imgs (there is no --files option), so
        # report that attribute; args.files would raise an AttributeError
        # and hide the actual problem.
        raise Exception("Empty dataset provided. Check your files argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    # one voter instance and one list of voted sentences per requested voter
    voters = []
    all_voter_sentences = []
    # one list of predicted sentences per model
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter_name in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter_name.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        # collect the raw sentence of every single model
        for sentences, model_prediction in zip(all_prediction_sentences, prediction):
            sentences.append(model_prediction.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        """Evaluate one list of predicted sentences against the preloaded gt."""
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        return evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)

    # models are keyed by their index ("0", "1", ...), voters by their name
    named_sentences = [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)]
    named_sentences += list(zip(args.voter, all_voter_sentences))

    full_evaluation = {}
    for name, data in named_sentences:
        full_evaluation[name] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
Esempio n. 12
0
def run(args):
    """Predict an Abbyy dataset, vote the results and write them as XML.

    ``args`` is the parsed command line namespace. The voted sentence of each
    predicted sample is stored in the sample's ``format`` object; optionally
    the extended prediction data is written per sample as ``.pred`` or
    ``.json``. Finally the book is written through ``XMLWriter``.

    Raises:
        Exception: On an unsupported extended prediction data format or when
            the dataset contains no samples.
    """
    # checks
    if args.extended_prediction_data_format not in ("pred", "json"):
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    json_paths = [cp if cp.endswith(".json") else cp + ".json"
                  for cp in args.checkpoint]
    args.checkpoint = [path[:-5] for path in glob_all(json_paths)]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    dataset = AbbyyDataSet(glob.glob(args.files),
                           skip_invalid=True,
                           remove_invalid=False,
                           binary=args.binary)
    dataset.load_samples(processes=args.processes,
                         progress_bar=not args.no_progress_bars)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(
        dataset, progress_bar=not args.no_progress_bars)

    # one image path entry per format of every page; zipped with the
    # predictions below
    input_image_files = [
        page.imgFile
        for page in dataset.book.pages
        for _ in page.getFormats()
    ]

    # output the voted results to the appropriate files
    for (result, sample), filepath in zip(do_prediction, input_image_files):
        for fold_no, fold_result in enumerate(result):
            fold_result.prediction.id = "fold_{}".format(fold_no)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence

        if args.verbose:
            sample_id = sample['id']
            # U+202A / U+202B open a left-to-right / right-to-left embedding,
            # selected by the base level of the sentence; U+202C closes it.
            embedding_marks = "\u202A\u202B"
            opening_mark = embedding_marks[get_base_level(sentence)]
            print("{}: '{}{}{}'".format(sample_id, opening_mark, sentence,
                                        "\u202C"))

        output_dir = args.output_dir or os.path.dirname(filepath)

        sample["format"].text = sentence

        if not args.extended_prediction_data:
            continue

        ps = Predictions()
        ps.line_path = filepath
        ps.predictions.extend([prediction] +
                              [r.prediction for r in result])

        data_format = args.extended_prediction_data_format
        if data_format == "pred":
            pred_path = os.path.join(output_dir, sample['id'] + ".pred")
            with open(pred_path, 'wb') as f:
                f.write(ps.SerializeToString())
        elif data_format == "json":
            json_path = os.path.join(output_dir, sample['id'] + ".json")
            with open(json_path, 'w') as f:
                # remove logits
                for stored in ps.predictions:
                    stored.logits.rows = 0
                    stored.logits.cols = 0
                    stored.logits.data[:] = []

                f.write(
                    MessageToJson(ps, including_default_value_fields=True))
        else:
            raise Exception("Unknown prediction format.")

    # NOTE: output_dir and filepath are the values of the last loop iteration
    XMLWriter(output_dir, os.path.dirname(filepath), dataset.book).write()

    print("All files written")
Esempio n. 13
0
class OCRProcessor(Processor):
	"""Processor that recognizes the text of extracted line images.

	The ``ocr`` option selects the mode: ``"FAKE"`` produces placeholder text
	without loading any model, ``"DRY"`` only logs which lines would be
	recognized, and any other value runs Calamari on the line images.
	A single model is used through ``Predictor``; several models are combined
	through ``MultiPredictor`` and a ``ConfidenceVoter``.
	"""

	def __init__(self, options):
		"""Store the options and locate the Calamari model files.

		Raises:
			click.BadParameter: If no model path is given (non-FAKE mode).
			FileNotFoundError: If no usable model is found at the model path.
		"""
		super().__init__(options)
		self._options = options
		self._ocr = self._options["ocr"]

		if self._ocr == "FAKE":
			self._model_path = None
			self._models = []
			self._line_height = 48
			self._chunk_size = 1
		else:
			if not options["model"]:
				raise click.BadParameter(
					"Please specify a model path", param="model")
			self._model_path = Path(options["model"])

			models = list(self._model_path.glob("*.json"))
			if not options["legacy_model"]:
				# keep only models that have an accompanying .h5 file
				models = [m for m in models if m.with_suffix(".h5").exists()]
			if len(models) < 1:
				raise FileNotFoundError(
					"no Calamari models found at %s" % self._model_path)
			self._models = models

			# both are determined when the models are actually loaded
			self._line_height = None
			self._chunk_size = None

		self._predictor = None
		self._voter = None
		# extra keyword arguments for predict_raw(); set in _load_models()
		self._predict_kwargs = dict()

		self._ignored = RegionsFilter(options["ignore"])

		if self._ocr != "FULL":
			logging.getLogger().setLevel(logging.INFO)

	@property
	def processor_name(self):
		"""Name of this processor, taken from the module loader."""
		return __loader__.name

	def _load_models(self):
		"""Lazily create the predictor (and voter) on first use.

		Does nothing if a predictor already exists or in ``"FAKE"`` mode.
		"""
		if self._predictor is not None:
			return

		if self._ocr == "FAKE":
			return

		batch_size = self._options["batch_size"]
		if batch_size > 0:
			batch_size_kwargs = dict(batch_size=batch_size)
		else:
			# non-positive batch size: do not pass batch_size at all
			batch_size_kwargs = dict()
		self._chunk_size = batch_size

		if len(self._models) == 1:
			self._predictor = Predictor(
				str(self._models[0]), **batch_size_kwargs)
			self._predict_kwargs = batch_size_kwargs
			self._voter = None
			self._line_height = int(self._predictor.model_params.line_height)
		else:
			logging.info(
				"using Calamari voting with %d models.", len(self._models))
			self._predictor = MultiPredictor(
				checkpoints=[str(p) for p in self._models],
				**batch_size_kwargs)
			self._predict_kwargs = dict()
			self._voter = ConfidenceVoter()
			self._line_height = int(self._predictor.predictors[0].model_params.line_height)

	def artifacts(self):
		"""Declare the inputs (reliable lines and tables) and the OCR output."""
		return [
			("reliable", Input(
				Artifact.LINES, Artifact.TABLES,
				stage=Stage.RELIABLE)),
			("output", Output(Artifact.OCR)),
		]

	def process(self, page_path: Path, reliable, output):
		"""Recognize all lines of one page and write one text file per line.

		Line images smaller than 6x6 pixels are not recognized; an empty text
		file is written for them instead. In ``"DRY"`` mode the line names
		are only logged and nothing is written.
		"""
		self._load_models()

		lines = reliable.lines.by_path

		extractor = LineExtractor(
			reliable.tables,
			self._line_height,
			self._options,
			min_confidence=reliable.lines.min_confidence)

		min_width = 6
		min_height = 6

		names = []
		empty_names = []
		images = []
		for stem, im in extractor(lines, ignored=self._ignored):
			if im.width >= min_width and im.height >= min_height:
				names.append("/".join(stem))
				images.append(np.array(im))
			else:
				empty_names.append("/".join(stem))

		if self._ocr == "DRY":
			logging.info(
				"will ocr the following lines:\n%s", "\n".join(sorted(names)))
			return

		chunk_size = self._chunk_size
		if chunk_size <= 0:
			# Process everything in one chunk. Clamp to 1 because a page
			# without usable lines would otherwise give range() a step of 0,
			# which raises ValueError.
			chunk_size = max(len(images), 1)

		texts = []

		if self._ocr == "FAKE":
			for name in names:
				texts.append("text for %s." % name)
		else:
			for i in range(0, len(images), chunk_size):
				for prediction in self._predictor.predict_raw(
					images[i:i + chunk_size], progress_bar=False, **self._predict_kwargs):

					if self._voter is not None:
						# multi-model result: reduce to one voted prediction
						prediction = self._voter.vote_prediction_result(prediction)
					texts.append(prediction.sentence)

		with output.ocr() as zf:
			for name, text in zip(names, texts):
				zf.writestr("%s.txt" % name, text)
			for name in empty_names:
				zf.writestr("%s.txt" % name, "")
Esempio n. 14
0
class CalamariRecognize(Processor):
    """Processor that recognizes text lines with Calamari and votes the result."""

    def __init__(self, *args, **kwargs):
        kwargs['ocrd_tool'] = OCRD_TOOL['tools']['ocrd-calamari-recognize']
        super(CalamariRecognize, self).__init__(*args, **kwargs)

    def _init_calamari(self):
        """Create the multi-model predictor and the voter from the parameters."""
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = TF_CPP_MIN_LOG_LEVEL

        self.predictor = MultiPredictor(
            checkpoints=glob(self.parameter['checkpoint']))

        params = VoterParams()
        params.type = VoterParams.Type.Value(self.parameter['voter'].upper())
        self.voter = voter_from_proto(params)

    def _make_file_id(self, input_file, n):
        """Derive the output file ID from the input file ID.

        The input group name is replaced by the output group name; if that
        leaves the ID unchanged, a padded ID built from the output group and
        ``n`` is used instead.
        """
        replaced = input_file.ID.replace(self.input_file_grp, self.output_file_grp)
        if replaced != input_file.ID:
            return replaced
        return concat_padded(self.output_file_grp, n)

    def _recognize_textline(self, line, region_image, region_xywh):
        """Predict one text line with all models and store the voted text."""
        line_image, _ = self.workspace.image_from_segment(line, region_image, region_xywh)
        pixels = np.array(line_image, dtype=np.uint8)

        # a single image is predicted, so only the first result is relevant
        fold_results = list(self.predictor.predict_raw([pixels], progress_bar=False))[0]
        for fold_no, fold_result in enumerate(fold_results):
            fold_result.prediction.id = "fold_{}".format(fold_no)

        voted = self.voter.vote_prediction_result(fold_results)
        voted.id = "voted"

        line.set_TextEquiv([
            TextEquivType(Unicode=voted.sentence, conf=voted.avg_char_probability)
        ])

    def process(self):
        """
        Performs the recognition.
        """

        self._init_calamari()

        for n, input_file in enumerate(self.input_files):
            page_id = input_file.pageId or input_file.ID
            log.info("INPUT FILE %i / %s", n, page_id)
            pcgts = page_from_file(self.workspace.download_file(input_file))

            page = pcgts.get_Page()
            page_image, page_xywh, _ = self.workspace.image_from_page(page, page_id)

            for region in pcgts.get_Page().get_TextRegion():
                region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh)

                textlines = region.get_TextLine()
                log.info("About to recognize %i lines of region '%s'", len(textlines), region.id)
                for line_no, line in enumerate(textlines):
                    log.debug("Recognizing line '%s' in region '%s'", line_no, region.id)
                    self._recognize_textline(line, region_image, region_xywh)

            # propagate the line texts to the higher levels
            _page_update_higher_textequiv_levels('line', pcgts)

            file_id = self._make_file_id(input_file, n)
            output_path = os.path.join(self.output_file_grp, file_id + '.xml')
            self.workspace.add_file(
                ID=file_id,
                file_grp=self.output_file_grp,
                pageId=input_file.pageId,
                mimetype=MIMETYPE_PAGE,
                local_filename=output_path,
                content=to_xml(pcgts))
Esempio n. 15
0
class CalamariRecognize(Processor):
    """Processor that recognizes text lines with Calamari.

    In addition to the line text, word and glyph results are written when the
    ``textequiv_level`` parameter is ``'word'`` or ``'glyph'``.
    """

    def __init__(self, *args, **kwargs):
        """Inject the tool description and version, then defer to the base class."""
        kwargs['ocrd_tool'] = OCRD_TOOL['tools'][TOOL]
        kwargs['version'] = OCRD_TOOL['version']
        super(CalamariRecognize, self).__init__(*args, **kwargs)

    def _init_calamari(self):
        """Create the predictor and the voter from the processor parameters.

        The ``checkpoint`` parameter is expanded with ``glob``; the ``voter``
        parameter is upper-cased and looked up in ``VoterParams.Type``.
        """
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = TF_CPP_MIN_LOG_LEVEL

        checkpoints = glob(self.parameter['checkpoint'])
        self.predictor = MultiPredictor(checkpoints=checkpoints)

        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(
            self.parameter['voter'].upper())
        self.voter = voter_from_proto(voter_params)

    def process(self):
        """
        Performs the recognition.

        For every input file, each text line of each text region is predicted
        with all checkpoints and voted. Existing text and word results of the
        line are replaced. Depending on the ``textequiv_level`` parameter,
        words (and glyphs) are added below the line. The result is added to
        the workspace as a PAGE file in the output file group.
        """

        assert_file_grp_cardinality(self.input_file_grp, 1)
        assert_file_grp_cardinality(self.output_file_grp, 1)

        self._init_calamari()

        for (n, input_file) in enumerate(self.input_files):
            page_id = input_file.pageId or input_file.ID
            log.info("INPUT FILE %i / %s", n, page_id)
            pcgts = page_from_file(self.workspace.download_file(input_file))

            page = pcgts.get_Page()
            # page_image_info is not used below
            page_image, page_xywh, page_image_info = self.workspace.image_from_page(
                page, page_id)

            for region in pcgts.get_Page().get_TextRegion():
                region_image, region_xywh = self.workspace.image_from_segment(
                    region, page_image, page_xywh)

                textlines = region.get_TextLine()
                log.info("About to recognize %i lines of region '%s'",
                         len(textlines), region.id)
                # line_no is not used below; lines are identified by line.id
                for (line_no, line) in enumerate(textlines):
                    log.debug("Recognizing line '%s' in region '%s'", line.id,
                              region.id)

                    line_image, line_coords = self.workspace.image_from_segment(
                        line, region_image, region_xywh)
                    line_image_np = np.array(line_image, dtype=np.uint8)

                    # A single image is predicted, so only the first result
                    # is taken; it is iterated per model ("fold") below.
                    raw_results = list(
                        self.predictor.predict_raw([line_image_np],
                                                   progress_bar=False))[0]
                    for i, p in enumerate(raw_results):
                        p.prediction.id = "fold_{}".format(i)

                    prediction = self.voter.vote_prediction_result(raw_results)
                    prediction.id = "voted"

                    # Build line text on our own
                    #
                    # Calamari does whitespace post-processing on prediction.sentence, while it does not do the same
                    # on prediction.positions. Do it on our own to have consistency.
                    #
                    # XXX Check Calamari's built-in post-processing on prediction.sentence

                    def _sort_chars(p):
                        """Filter and sort chars of prediction p

                        Drops chars with an empty ``char`` and chars whose
                        probability is below the ``glyph_conf_cutoff``
                        parameter, then sorts by descending probability.
                        """
                        chars = p.chars
                        chars = [
                            c for c in chars if c.char
                        ]  # XXX Note that omission probabilities are not normalized?!
                        chars = [
                            c for c in chars if c.probability >=
                            self.parameter['glyph_conf_cutoff']
                        ]
                        chars = sorted(chars,
                                       key=lambda k: k.probability,
                                       reverse=True)
                        return chars

                    def _drop_leading_spaces(positions):
                        """Drop leading positions whose best char is a space."""
                        # NOTE(review): `_sort_chars(p)[0]` raises IndexError
                        # if the cutoff removes every char of a position -
                        # confirm this cannot happen.
                        return list(
                            itertools.dropwhile(
                                lambda p: _sort_chars(p)[0].char == " ",
                                positions))

                    def _drop_trailing_spaces(positions):
                        """Drop trailing space positions (reverse, drop leading, reverse)."""
                        return list(
                            reversed(_drop_leading_spaces(
                                reversed(positions))))

                    def _drop_double_spaces(positions):
                        """Collapse each run of consecutive space positions to its first one."""
                        def _drop_double_spaces_generator(positions):
                            last_was_space = False
                            for p in positions:
                                # NOTE(review): this checks p.chars[0], while
                                # the other helpers use _sort_chars(p)[0] -
                                # confirm that both always agree.
                                if p.chars[0].char == " ":
                                    if not last_was_space:
                                        yield p
                                    last_was_space = True
                                else:
                                    yield p
                                    last_was_space = False

                        return list(_drop_double_spaces_generator(positions))

                    positions = prediction.positions
                    positions = _drop_leading_spaces(positions)
                    positions = _drop_trailing_spaces(positions)
                    positions = _drop_double_spaces(positions)
                    positions = list(positions)

                    # One char per remaining position: the most probable one.
                    line_text = ''.join(
                        _sort_chars(p)[0].char for p in positions)
                    if line_text != prediction.sentence:
                        log.warning(
                            "Our own line text is not the same as Calamari's: '%s' != '%s'",
                            line_text, prediction.sentence)

                    # Delete existing results
                    if line.get_TextEquiv():
                        log.warning("Line '%s' already contained text results",
                                    line.id)
                    line.set_TextEquiv([])
                    if line.get_Word():
                        log.warning(
                            "Line '%s' already contained word segmentation",
                            line.id)
                    line.set_Word([])

                    # Save line results
                    line_conf = prediction.avg_char_probability
                    line.set_TextEquiv(
                        [TextEquivType(Unicode=line_text, conf=line_conf)])

                    # Save word results
                    #
                    # Calamari OCR does not provide word positions, so we infer word positions from a. text segmentation
                    # and b. the glyph positions. This is necessary because the PAGE XML format enforces a strict
                    # hierarchy of lines > words > glyphs.

                    def _words(s):
                        """Split words based on spaces and include spaces as 'words'

                        Yields alternating runs of spaces and non-spaces. The
                        final run is always yielded, so an empty ``s`` yields
                        one empty string.
                        """
                        # None until the first char has been seen, then True
                        # while inside a run of spaces and False otherwise.
                        spaces = None
                        word = ''
                        for c in s:
                            if c == ' ' and spaces is True:
                                word += c
                            elif c != ' ' and spaces is False:
                                word += c
                            else:
                                # run type changes (or first char): emit the
                                # finished run and start a new one
                                if word:
                                    yield word
                                word = c
                                spaces = (c == ' ')
                        yield word

                    if self.parameter['textequiv_level'] in ['word', 'glyph']:
                        word_no = 0
                        # Index of the current word's first position in
                        # `positions`; presumably each position contributes
                        # exactly one character to line_text - TODO confirm.
                        i = 0

                        for word_text in _words(line_text):
                            word_length = len(word_text)
                            # runs consisting only of spaces are skipped but
                            # still advance the position index below
                            if not all(c == ' ' for c in word_text):
                                word_positions = positions[i:i + word_length]
                                word_start = word_positions[0].global_start
                                word_end = word_positions[-1].global_end

                                # Box from word start to word end over the
                                # full height of the line image.
                                polygon = polygon_from_x0y0x1y1([
                                    word_start, 0, word_end, line_image.height
                                ])
                                points = points_from_polygon(
                                    coordinates_for_segment(
                                        polygon, None, line_coords))
                                # XXX Crop to line polygon?

                                word = WordType(id='%s_word%04d' %
                                                (line.id, word_no),
                                                Coords=CoordsType(points))
                                word.add_TextEquiv(
                                    TextEquivType(Unicode=word_text))

                                if self.parameter[
                                        'textequiv_level'] == 'glyph':
                                    for glyph_no, p in enumerate(
                                            word_positions):
                                        glyph_start = p.global_start
                                        glyph_end = p.global_end

                                        # Same construction as for the word
                                        # box, restricted to this glyph.
                                        polygon = polygon_from_x0y0x1y1([
                                            glyph_start, 0, glyph_end,
                                            line_image.height
                                        ])
                                        points = points_from_polygon(
                                            coordinates_for_segment(
                                                polygon, None, line_coords))

                                        glyph = GlyphType(
                                            id='%s_glyph%04d' %
                                            (word.id, glyph_no),
                                            Coords=CoordsType(points))

                                        # Add predictions (= TextEquivs)
                                        char_index_start = 1  # Must start with 1, see https://ocr-d.github.io/page#multiple-textequivs
                                        for char_index, char in enumerate(
                                                _sort_chars(p),
                                                start=char_index_start):
                                            glyph.add_TextEquiv(
                                                TextEquivType(
                                                    Unicode=char.char,
                                                    index=char_index,
                                                    conf=char.probability))

                                        word.add_Glyph(glyph)

                                line.add_Word(word)
                                word_no += 1

                            i += word_length

            _page_update_higher_textequiv_levels('line', pcgts)

            # Add metadata about this operation and its runtime parameters:
            metadata = pcgts.get_Metadata()  # ensured by from_file()
            metadata.add_MetadataItem(
                MetadataItemType(
                    type_="processingStep",
                    name=self.ocrd_tool['steps'][0],
                    value=TOOL,
                    Labels=[
                        LabelsType(externalModel="ocrd-tool",
                                   externalId="parameters",
                                   Label=[
                                       LabelType(type_=name,
                                                 value=self.parameter[name])
                                       for name in self.parameter.keys()
                                   ])
                    ]))

            file_id = make_file_id(input_file, self.output_file_grp)
            pcgts.set_pcGtsId(file_id)
            self.workspace.add_file(ID=file_id,
                                    file_grp=self.output_file_grp,
                                    pageId=input_file.pageId,
                                    mimetype=MIMETYPE_PAGE,
                                    local_filename=os.path.join(
                                        self.output_file_grp,
                                        file_id + '.xml'),
                                    content=to_xml(pcgts))
Esempio n. 16
0
def main():
    """Predict image files, vote the results and write ``.pred.txt`` files.

    Parses the command line, predicts all given files with every checkpoint,
    votes the per-model results and writes the voted sentence next to each
    input file (or into ``--output_dir``). With ``--extended_prediction_data``
    the full prediction is additionally stored as ``.pred`` or ``.json``.

    Raises:
        Exception: On an unsupported extended prediction data format or when
            the dataset contains no samples.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--files", nargs="+", required=True, default=[],
                        help="List all image files that shall be processed")
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="The batch size during the prediction (number of lines to process in parallel)")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, default="confidence_voter_default_ctc",
                        help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
                             "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--output_dir", type=str,
                        help="By default the prediction files will be written to the same directory as the given files. "
                             "You can use this argument to specify a specific output dir for the prediction files.")
    parser.add_argument("--extended_prediction_data", action="store_true",
                        help="Write: Predicted string, labels; position, probabilities and alternatives of chars to a .pred (protobuf) file")
    parser.add_argument("--extended_prediction_data_format", type=str, default="json",
                        help="Extension format: Either pred or json. Note that json will not print logits.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")

    args = parser.parse_args()

    # checks
    if args.extended_prediction_data_format not in ("pred", "json"):
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    json_paths = [cp if cp.endswith(".json") else cp + ".json"
                  for cp in args.checkpoint]
    args.checkpoint = [path[:-5] for path in glob_all(json_paths)]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = sorted(glob_all(args.files))

    # skip invalid files, but keep them so that empty predictions are created
    dataset = FileDataSet(input_image_files,
                          skip_invalid=True,
                          remove_invalid=False)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint)
    do_prediction = predictor.predict_dataset(
        dataset,
        batch_size=args.batch_size,
        processes=args.processes,
        progress_bar=not args.no_progress_bars)

    # output the voted results to the appropriate files
    for (result, sample), filepath in zip(do_prediction, input_image_files):
        for fold_no, fold_result in enumerate(result):
            fold_result.prediction.id = "fold_{}".format(fold_no)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        if args.verbose:
            print("{}: '{}'".format(sample['id'], sentence))

        output_dir = args.output_dir or os.path.dirname(filepath)

        text_path = os.path.join(output_dir, sample['id'] + ".pred.txt")
        with codecs.open(text_path, 'w', 'utf-8') as f:
            f.write(sentence)

        if not args.extended_prediction_data:
            continue

        ps = Predictions()
        ps.line_path = filepath
        ps.predictions.extend([prediction] +
                              [r.prediction for r in result])

        data_format = args.extended_prediction_data_format
        if data_format == "pred":
            pred_path = os.path.join(output_dir, sample['id'] + ".pred")
            with open(pred_path, 'wb') as f:
                f.write(ps.SerializeToString())
        elif data_format == "json":
            json_path = os.path.join(output_dir, sample['id'] + ".json")
            with open(json_path, 'w') as f:
                # remove logits
                for stored in ps.predictions:
                    stored.logits.rows = 0
                    stored.logits.cols = 0
                    stored.logits.data[:] = []

                f.write(
                    MessageToJson(ps, including_default_value_fields=True))
        else:
            raise Exception("Unknown prediction format.")

    print("All files written")
Esempio n. 17
0
def main():
    """Evaluate every checkpoint and every requested voter against ground truth.

    Command-line entry point. The images matched by ``--eval_imgs`` are
    predicted with all models given via ``--checkpoint``; the per-model
    results are then combined by each voter listed in ``--voter``. Every
    single model (keyed by its index as a string) and every voter (keyed by
    its name) is evaluated against the ground-truth texts, whose paths are
    derived from the image paths by replacing the extension(s) with
    ``.gt.txt``.

    With ``--verbose`` the collected evaluation dict is printed; with
    ``--dump PATH`` it is pickled to ``PATH`` together with the ground-truth
    paths and texts.

    Raises:
        Exception: if the dataset is empty, or if the number of predicted
            sentences does not match the number of dataset samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs",
                        type=str,
                        nargs="+",
                        required=True,
                        help="The evaluation files")
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument(
        "--voter",
        type=str,
        nargs="+",
        default=[
            "sequence_voter", "confidence_voter_default_ctc",
            "confidence_voter_fuzzy_ctc"
        ],
        # The default is the full list above, so no single voter is "the default".
        help="The voting algorithms to use (default: all of them). Possible values: "
        "confidence_voter_default_ctc, confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size",
                        type=int,
                        default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump",
                        type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # Allow the user to pass the model's json file, but strip the extension
    # because further processing expects the bare checkpoint path.
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp)
                       for cp in args.checkpoint]

    # Resolve the inputs exactly once so that images and ground-truth paths
    # are guaranteed to be index-aligned.
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = FileDataSet(images=gt_images,
                          texts=gt_txts,
                          skip_invalid=not args.no_skip_invalid_gt)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # The option is called --eval_imgs; this parser has no ``files``.
        raise Exception(
            "Empty dataset provided. Check your --eval_imgs argument (got {})!".
            format(args.eval_imgs))

    # Predict with all models at once.
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    # One voter instance per requested voter name, in command-line order.
    voters = []
    for voter_name in args.voter:
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter_name.upper())
        voters.append(voter_from_proto(voter_params))

    # Sentence collectors: one list per model and one list per voter.
    all_prediction_sentences = [[] for _ in range(n_models)]
    all_voter_sentences = [[] for _ in voters]

    for prediction, _sample in do_prediction:
        for model_sentences, model_result in zip(all_prediction_sentences, prediction):
            model_sentences.append(model_result.sentence)

        # Vote the per-model results with every voter.
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(
                voter.vote_prediction_result(prediction).sentence)

    # Evaluation uses the text preprocessor of the first model.
    text_preproc = text_processor_from_proto(
        predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        """Evaluate one list of predicted sentences against the preloaded gt."""
        if len(predicted_sentences) != len(dataset):
            raise Exception(
                "Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = RawDataSet(texts=predicted_sentences)

        return evaluator.run(pred_dataset=pred_data_set,
                             progress_bar=True,
                             processes=args.processes)

    # Single models are keyed by their index, voters by their name.
    candidates = [(str(index), sentences)
                  for index, sentences in enumerate(all_prediction_sentences)]
    candidates.extend(zip(args.voter, all_voter_sentences))

    full_evaluation = {}
    for name, sentences in candidates:
        full_evaluation[name] = {
            "eval": single_evaluation(sentences),
            "data": sentences,
        }

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump(
                {
                    "full": full_evaluation,
                    "gt_txts": gt_txts,
                    "gt": dataset.text_samples()
                }, f)
Esempio n. 18
0
def run(args):
    """Predict all input lines, vote the model outputs and store the results.

    ``args`` is an argparse-style namespace. Attributes read here: ``files``,
    ``checkpoint``, ``voter``, ``text_files``, ``dataset``,
    ``pagexml_text_index``, ``batch_size``, ``processes``,
    ``no_progress_bars``, ``verbose``, ``output_dir``,
    ``extended_prediction_data`` and ``extended_prediction_data_format``.
    If ``args.files`` is a single path ending in ``json``, that file is
    loaded first and each of its keys overrides the attribute of the same
    name on ``args``.

    For every sample the voted sentence is handed to ``dataset.store_text``
    with the ``.pred.txt`` extension. If ``args.extended_prediction_data`` is
    set, the voted prediction plus all per-model predictions are additionally
    written as ``<id>.pred`` (serialized protobuf) or ``<id>.json`` (logits
    removed) into ``args.output_dir`` or, if that is unset, the directory of
    the line path.

    Raises:
        Exception: if the extended prediction data format is neither
            ``pred`` nor ``json``, or if the dataset is empty.
    """
    # Optionally replace/extend the arguments from a json file.
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            setattr(args, key, value)

    # Validate before any model is loaded.
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception("Only 'pred' and 'json' are allowed extended prediction data formats")

    # Add json as extension, resolve wildcards, expand user, ... and remove .json again.
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json") for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # Create the voter.
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # Load files.
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # Skip invalid files and remove them; there won't be predictions of invalid files.
    dataset = create_dataset(
        args.dataset,
        DataSetMode.PREDICT,
        input_image_files,
        args.text_files,
        skip_invalid=True,
        remove_invalid=True,
        args={'text_index': args.pagexml_text_index},
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception("Empty dataset provided. Check your files argument (got {})!".format(args.files))

    # Predict with all models.
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=not args.no_progress_bars)

    avg_sentence_confidence = 0
    n_predictions = 0

    # Output the voted results to the appropriate files.
    for result, sample in do_prediction:
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # Vote the results (if only one model is given, this will just return the sentences).
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # Wrap the sentence in a directional embedding (U+202A or U+202B,
            # picked by the sentence's base level) and close it with U+202C.
            lr = "\u202A\u202B"
            print("{}: '{}{}{}'".format(sample['id'], lr[get_base_level(sentence)], sentence, "\u202C" ))

        output_dir = args.output_dir

        dataset.store_text(sentence, sample, output_dir=output_dir, extension=".pred.txt")

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample['image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] + [r.prediction for r in result])

            # Fall back to the directory of the line itself. An empty string
            # means "current directory": nothing to create, and passing ''
            # to mkdir/makedirs would raise.
            extended_dir = output_dir if output_dir else os.path.dirname(ps.line_path)
            if extended_dir:
                # makedirs also creates missing parents and tolerates the
                # directory already existing.
                os.makedirs(extended_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(extended_dir, sample['id'] + ".pred"), 'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(extended_dir, sample['id'] + ".json"), 'w') as f:
                    # Remove the logits from the messages held by ps before dumping.
                    for stored_prediction in ps.predictions:
                        stored_prediction.logits.rows = 0
                        stored_prediction.logits.cols = 0
                        stored_prediction.logits.data[:] = []

                    f.write(MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # Guard the average: the predictor may yield nothing.
    if n_predictions:
        print("Average sentence confidence: {:.2%}".format(avg_sentence_confidence / n_predictions))
    else:
        print("No predictions were produced")

    dataset.store()
    print("All files written")