Ejemplo n.º 1
0
 def test_eval_files_with_different_sources(self):
     """Evaluate gt and predictions that are passed as two separate file lists.

     Predictions are first written next to the test images with the custom
     extension ``.ext-pred.txt``; the evaluation then receives the gt files and
     those prediction files explicitly.
     """
     test_dir = os.path.join(this_dir, "data", "uw3_50lines", "test")

     image_files = sorted(glob_all([os.path.join(test_dir, "*.png")]))
     run_predict(predict_args(data=FileDataParams(pred_extension=".ext-pred.txt", images=image_files)))

     # Resolve both lists only after predicting, so the prediction files exist
     gt_files = sorted(glob_all([os.path.join(test_dir, "*.gt.txt")]))
     pred_files = sorted(glob_all([os.path.join(test_dir, "*.ext-pred.txt")]))
     result = run_eval(
         eval_args(
             gt_data=FileDataParams(texts=gt_files),
             pred_data=FileDataParams(texts=pred_files),
         ))
     self.assertLess(result["avg_ler"], 0.0009, msg="Current best model yields about 0.09% CER")
Ejemplo n.º 2
0
    def prepare_for_mode(self, mode: PipelineMode):
        """Resolve ``self.images`` and ``self.texts`` for the given pipeline mode.

        Image globs are expanded and sorted. Without explicit texts, the gt paths
        are derived from the image paths using ``self.gt_extension``. With explicit
        texts, image and text lists are matched by file name when ``mode`` is in
        ``INPUT_PROCESSOR``; otherwise the images are dropped (set to ``None``).
        In training/evaluation, duplicated files only produce warnings.

        Raises:
            Exception: if a matched image and text file have different basenames.
        """
        logger.info("Resolving input files")
        images = sorted(glob_all(self.images))

        if self.texts:
            texts = sorted(glob_all(self.texts))
            if mode not in INPUT_PROCESSOR:
                images = None
            else:
                images, texts = keep_files_with_same_file_name(images, texts)
                for image_path, text_path in zip(images, texts):
                    image_stem = split_all_ext(os.path.basename(image_path))[0]
                    text_stem = split_all_ext(os.path.basename(text_path))[0]
                    if image_stem != text_stem:
                        raise Exception(f"Expected identical basenames of file: {image_path} and {text_path}")
        else:
            texts = [split_all_ext(path)[0] + self.gt_extension for path in images]

        if mode in {PipelineMode.TRAINING, PipelineMode.EVALUATION}:
            has_duplicate_texts = len(set(texts)) != len(texts)
            if has_duplicate_texts:
                logger.warning(
                    "Some ground truth text files occur more than once in the data set "
                    "(ignore this warning, if this was intended)."
                )
            has_duplicate_images = len(set(images)) != len(images)
            if has_duplicate_images:
                logger.warning(
                    "Some images occur more than once in the data set. This warning should usually not be ignored."
                )

        self.images = images
        self.texts = texts
Ejemplo n.º 3
0
def create_train_dataset(args, dataset_args=None):
    """Create the training dataset described by the command line ``args``.

    Images are resolved (and sorted) from ``args.files``. Ground truth files are
    either resolved from ``args.text_files`` and matched to the images by
    basename, or derived from the image paths using the gt extension
    (``args.gt_extension`` or the default of ``args.dataset``). If no gt extension
    is available, every text entry is ``None``.

    Raises:
        Exception: if a matched image and gt file have different basenames, or if
            an image occurs more than once in the data set.
    """
    gt_extension = args.gt_extension if args.gt_extension is not None else DataSetType.gt_extension(args.dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        if gt_extension:
            gt_txt_files = [split_all_ext(f)[0] + gt_extension for f in input_image_files]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    # gt paths are derived from / matched to the image paths, so a duplicated gt
    # path means a duplicated image. Without any gt every entry is None, which
    # would always look like a duplicate, so check the images directly instead.
    if gt_txt_files and all(f is None for f in gt_txt_files):
        files_to_check = input_image_files
    else:
        files_to_check = gt_txt_files
    if len(set(files_to_check)) != len(files_to_check):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset
Ejemplo n.º 4
0
def uw3_trainer_params(with_validation=False, with_split=False, preload=True, debug=False):
    """Build trainer params for training on the uw3_50lines test data.

    Args:
        with_validation: use the ``test`` images as validation data.
        with_split: split 20% of the training data off for validation
            (takes precedence over ``with_validation``).
        preload: whether the data is preloaded.
        debug: enable graph-construction debugging and eager execution.

    Batch size and number of processes are set to 1 for training and validation.
    """
    uw3_dir = os.path.join(this_dir, "data", "uw3_50lines")

    params = CalamariTestScenario.default_trainer_params()
    params.scenario.debug_graph_construction = debug
    params.force_eager = debug

    train_data = FileDataParams(
        images=glob_all([os.path.join(uw3_dir, "train", "*.png")]),
        preload=preload,
    )

    if with_split:
        params.gen = CalamariSplitTrainerPipelineParams(validation_split_ratio=0.2, train=train_data)
    elif not with_validation:
        params.gen = CalamariTrainOnlyPipelineParams(train=train_data)
    else:
        val_data = params.gen.val
        val_data.images = glob_all([os.path.join(uw3_dir, "test", "*.png")])
        val_data.preload = preload
        params.gen.train = train_data
        params.gen.__post_init__()

    for setup in (params.gen.setup.val, params.gen.setup.train):
        setup.batch_size = 1
        setup.num_processes = 1
    post_init(params)
    return params
Ejemplo n.º 5
0
    def prepare_for_mode(self, mode: PipelineMode) -> 'PipelineParams':
        """Return a copy of these params with the file lists resolved for ``mode``.

        A string ``self.type`` is resolved in place to a ``DataSetType`` when
        possible; names registered in ``DataReaderFactory.CUSTOM_READERS`` stay
        strings. The returned copy carries the same resolved type.

        For standard dataset types other than RAW and GENERATED_LINE:
        - the ``files``/``text_files`` globs are expanded and sorted;
        - missing gt paths are derived from the image paths via ``gt_extension``;
        - image/gt pairs are matched by basename when ``mode`` is in
          ``INPUT_PROCESSOR``.

        ``self`` keeps its unresolved file lists.

        Raises:
            KeyError: if ``type`` is a string that is neither a standard
                ``DataSetType`` nor a registered custom reader.
            Exception: if a matched image and gt file have different basenames.
        """
        from calamari_ocr.ocr.dataset.datareader.factory import DataReaderFactory
        assert (self.type is not None)
        params_out = deepcopy(self)
        # Training dataset
        logger.info("Resolving input files")
        if isinstance(self.type, str):
            try:
                self.type = DataSetType.from_string(self.type)
            except ValueError:
                # Not a valid type, must be custom
                if self.type not in DataReaderFactory.CUSTOM_READERS:
                    raise KeyError(
                        f"DataSetType {self.type} is neither a standard DataSetType or preset as custom "
                        f"reader ({list(DataReaderFactory.CUSTOM_READERS.keys())})"
                    )
            # params_out was copied before the type was resolved; keep it in sync so
            # the returned params carry the resolved type instead of the raw string.
            params_out.type = self.type
        if not isinstance(self.type, str) and self.type not in {
                DataSetType.RAW, DataSetType.GENERATED_LINE
        }:
            input_image_files = sorted(glob_all(
                self.files)) if self.files else None

            if not self.text_files:
                # gt paths can only be derived when there are image paths to derive from
                if self.gt_extension and input_image_files is not None:
                    gt_txt_files = [
                        split_all_ext(f)[0] + self.gt_extension
                        for f in input_image_files
                    ]
                else:
                    gt_txt_files = None
            else:
                gt_txt_files = sorted(glob_all(self.text_files))
                if mode in INPUT_PROCESSOR:
                    input_image_files, gt_txt_files = keep_files_with_same_file_name(
                        input_image_files, gt_txt_files)
                    for img, gt in zip(input_image_files, gt_txt_files):
                        if split_all_ext(
                                os.path.basename(img))[0] != split_all_ext(
                                    os.path.basename(gt))[0]:
                            raise Exception(
                                "Expected identical basenames of file: {} and {}"
                                .format(img, gt))
                else:
                    input_image_files = None

            if mode in {PipelineMode.Training, PipelineMode.Evaluation}:
                # Either list may legitimately be None (no files / no gt given)
                if gt_txt_files is not None and len(set(gt_txt_files)) != len(gt_txt_files):
                    logger.warning(
                        "Some ground truth text files occur more than once in the data set "
                        "(ignore this warning, if this was intended).")
                if input_image_files is not None and len(set(input_image_files)) != len(input_image_files):
                    logger.warning(
                        "Some images occur more than once in the data set. "
                        "This warning should usually not be ignored.")

            params_out.files = input_image_files
            params_out.text_files = gt_txt_files
        return params_out
Ejemplo n.º 6
0
def data_reader_from_params(mode: PipelineMode,
                            params: PipelineParams) -> DataReader:
    """Create a data reader for ``mode`` from ``params``.

    For dataset types other than RAW and GENERATED_LINE:
    - the ``files``/``text_files`` globs are expanded and sorted;
    - missing gt paths are derived from the image paths via ``gt_extension``;
    - image/gt pairs are matched by basename when ``mode`` is in ``INPUT_PROCESSOR``.

    For RAW and GENERATED_LINE the file lists are passed through unchanged.

    Raises:
        Exception: if a matched image and gt file have different basenames.
    """
    assert (params.type is not None)
    from calamari_ocr.ocr.dataset.dataset_factory import create_data_reader
    # Training dataset
    logger.info("Resolving input files")
    if params.type not in {DataSetType.RAW, DataSetType.GENERATED_LINE}:
        input_image_files = sorted(glob_all(
            params.files)) if params.files else None

        if not params.text_files:
            # gt paths can only be derived when there are image paths to derive from
            if params.gt_extension and input_image_files is not None:
                gt_txt_files = [
                    split_all_ext(f)[0] + params.gt_extension
                    for f in input_image_files
                ]
            else:
                gt_txt_files = None
        else:
            gt_txt_files = sorted(glob_all(params.text_files))
            if mode in INPUT_PROCESSOR:
                input_image_files, gt_txt_files = keep_files_with_same_file_name(
                    input_image_files, gt_txt_files)
                for img, gt in zip(input_image_files, gt_txt_files):
                    if split_all_ext(
                            os.path.basename(img))[0] != split_all_ext(
                                os.path.basename(gt))[0]:
                        raise Exception(
                            "Expected identical basenames of file: {} and {}".
                            format(img, gt))
            else:
                input_image_files = None

        if mode in {PipelineMode.Training, PipelineMode.Evaluation}:
            # Either list may legitimately be None (no files / no gt given)
            if gt_txt_files is not None and len(set(gt_txt_files)) != len(gt_txt_files):
                logger.warning(
                    "Some ground truth text files occur more than once in the data set "
                    "(ignore this warning, if this was intended).")
            if input_image_files is not None and len(set(input_image_files)) != len(input_image_files):
                logger.warning(
                    "Some images occur more than once in the data set. "
                    "This warning should usually not be ignored.")
    else:
        input_image_files = params.files
        gt_txt_files = params.text_files

    dataset = create_data_reader(
        params.type,
        mode,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=params.skip_invalid,
        args=params.data_reader_args
        if params.data_reader_args else FileDataReaderArgs(),
    )
    logger.info(f"Found {len(dataset)} files in the dataset")
    return dataset
Ejemplo n.º 7
0
def create_test_dataset(
    cfg: CfgNode,
    dataset_args=None
) -> Union[List[Union[RawDataSet, FileDataSet, AbbyyDataSet, PageXMLDataset,
                      Hdf5DataSet, ExtendedPredictionDataSet,
                      GeneratedLineDataset]], None]:
    """Create one validation dataset per entry of ``cfg.DATASET.VALID.PATH``.

    Ground truth is either taken from the matching entry of
    ``cfg.DATASET.VALID.TEXT_FILES`` (matched to the images by basename), or
    derived from the image paths using ``cfg.DATASET.VALID.GT_EXTENSION``. Each
    path is also registered with ``dataregistry``.

    Returns:
        A list with one dataset per validation path, or ``None`` if no
        validation path is configured.

    Raises:
        AssertionError: if TEXT_FILES is set but does not have one entry per PATH.
        Exception: if a matched image and gt file have different basenames, or a
            validation gt file occurs more than once.
    """
    if cfg.DATASET.VALID.TEXT_FILES:
        assert len(cfg.DATASET.VALID.PATH) == len(cfg.DATASET.VALID.TEXT_FILES)

    if cfg.DATASET.VALID.PATH:
        validation_dataset_list = []
        print("Resolving validation files")
        for i, valid_path in enumerate(cfg.DATASET.VALID.PATH):
            # Sort like the training files: gives a deterministic order and lines the
            # images up with the (sorted) gt files before they are matched below.
            validation_image_files = sorted(glob_all(valid_path))
            dataregistry.register(
                i, os.path.basename(os.path.dirname(valid_path)),
                len(validation_image_files))

            if not cfg.DATASET.VALID.TEXT_FILES:
                val_txt_files = [
                    split_all_ext(f)[0] + cfg.DATASET.VALID.GT_EXTENSION
                    for f in validation_image_files
                ]
            else:
                val_txt_files = sorted(
                    glob_all(cfg.DATASET.VALID.TEXT_FILES[i]))
                validation_image_files, val_txt_files = keep_files_with_same_file_name(
                    validation_image_files, val_txt_files)
                for img, gt in zip(validation_image_files, val_txt_files):
                    if split_all_ext(
                            os.path.basename(img))[0] != split_all_ext(
                                os.path.basename(gt))[0]:
                        raise Exception(
                            "Expected identical basenames of validation file: {} and {}"
                            .format(img, gt))

            if len(set(val_txt_files)) != len(val_txt_files):
                raise Exception(
                    "Some validation images are occurring more than once in the data set."
                )

            validation_dataset = create_dataset(
                cfg.DATASET.VALID.TYPE,
                DataSetMode.TRAIN,
                images=validation_image_files,
                texts=val_txt_files,
                skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT,
                args=dataset_args,
            )
            print("Found {} files in the validation dataset".format(
                len(validation_dataset)))
            validation_dataset_list.append(validation_dataset)
    else:
        validation_dataset_list = None

    return validation_dataset_list
Ejemplo n.º 8
0
 def prepare_for_mode(self, mode: PipelineMode):
     """Resolve the image and xml file lists.

     Both lists are glob-expanded and sorted. Missing xml paths are derived from
     the image paths via ``self.gt_extension``. If no images are given, the image
     list is filled with ``None`` (one per xml file).
     """
     self.images = sorted(glob_all(self.images))
     # Expand xml patterns up front, not only when no images are given, so that
     # explicitly passed globs are always resolved. Tolerate a missing (None) list.
     self.xml_files = sorted(glob_all(self.xml_files)) if self.xml_files else []
     if not self.xml_files:
         self.xml_files = [
             split_all_ext(f)[0] + self.gt_extension for f in self.images
         ]
     if not self.images:
         self.images = [None] * len(self.xml_files)
Ejemplo n.º 9
0
def main():
    """Resume training from a checkpoint.

    The positional ``files`` are the training line images, with ground truth in
    sibling ``.gt.txt`` files. ``--validation`` optionally gives validation images
    in the same layout. The training parameters are read from
    ``<checkpoint>.json``.

    Raises:
        Exception: if a (training or validation) image occurs more than once.
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="The checkpoint used to resume")
    parser.add_argument("--validation",
                        type=str,
                        nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("files",
                        type=str,
                        nargs="+",
                        help="The files to use for training")

    args = parser.parse_args()

    # Train dataset
    input_image_files = glob_all(args.files)
    gt_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in input_image_files]

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = FileDataSet(input_image_files, gt_txt_files)

    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        validation_image_files = glob_all(args.validation)
        val_txt_files = [
            split_all_ext(f)[0] + ".gt.txt" for f in validation_image_files
        ]

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception(
                "Some validation images are occurring more than once in the data set."
            )

        validation_dataset = FileDataSet(validation_image_files, val_txt_files)
        print("Found {} files in the validation dataset".format(
            len(validation_dataset)))
    else:
        validation_dataset = None

    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    # Only the parsed params are needed: close the json before the (long running)
    # training starts instead of holding the handle open for its whole duration.
    trainer = Trainer(checkpoint_params,
                      dataset,
                      validation_dataset=validation_dataset,
                      restore=args.checkpoint)
    trainer.train(progress_bar=True)
Ejemplo n.º 10
0
 def test_eval_list_files(self):
     """Predict and evaluate using the ``test.files`` / ``test.gt.files`` inputs."""
     uw3_dir = os.path.join(this_dir, "data", "uw3_50lines")

     image_list = sorted(glob_all([os.path.join(uw3_dir, "test.files")]))
     run_predict(predict_args(data=FileDataParams(images=image_list)))

     gt_list = sorted(glob_all([os.path.join(uw3_dir, "test.gt.files")]))
     result = run_eval(eval_args(gt_data=FileDataParams(texts=gt_list)))
     self.assertLess(result["avg_ler"], 0.0009, msg="Current best model yields about 0.09% CER")
Ejemplo n.º 11
0
 def to_prediction(self):
     """Return a copy whose ``files`` point at the prediction files.

     As a side effect ``self.files`` is replaced by its sorted glob expansion. The
     copy's files are those paths with all extensions replaced by
     ``self.pred_extension``.
     """
     resolved_files = sorted(glob_all(self.files))
     self.files = resolved_files
     prediction_params = deepcopy(self)
     prediction_params.files = [split_all_ext(path)[0] + self.pred_extension for path in resolved_files]
     return prediction_params
Ejemplo n.º 12
0
def main():
    """Convert serialized ``Predictions`` protobuf files to JSON.

    Each input file is written to a ``.json`` file next to it; the output path is
    the input path with all extensions replaced (via ``split_all_ext``). Unless
    ``--logits`` is given, the logits of every prediction are cleared first.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", type=str, default=[], nargs="+", required=True,
                        help="Protobuf files to convert")
    parser.add_argument("--logits", action="store_true",
                        help="Do write logits")
    args = parser.parse_args()

    for pb_path in tqdm(glob_all(args.files), desc="Converting"):
        predictions = Predictions()
        with open(pb_path, 'rb') as pb_file:
            predictions.ParseFromString(pb_file.read())

        if not args.logits:
            # Empty the logit matrix of each prediction so it is not exported
            for prediction in predictions.predictions:
                logits = prediction.logits
                logits.rows = 0
                logits.cols = 0
                logits.data[:] = []

        json_path = split_all_ext(pb_path)[0] + ".json"
        with open(json_path, 'w') as json_file:
            json_file.write(MessageToJson(predictions, including_default_value_fields=True))
Ejemplo n.º 13
0
    def test_prediction_extended_and_positions(self):
        """Extended prediction output of a trained model has plausible character positions."""
        # A trained model is required so that sentence and positions can be checked
        args = predict_args()
        model_dir = os.path.join(this_dir, "models", f"version{SavedCalamariModel.VERSION}")
        args.checkpoint = [os.path.join(model_dir, "0.ckpt")]
        args.extended_prediction_data = True
        run(args)

        jsons = [os.path.join(this_dir, "data", "uw3_50lines", "test", "*.json")]
        run_compute_avg_pred(ExtendedPredictionDataParams(files=jsons))

        def assert_pos_in_interval(pos, lower, upper):
            self.assertGreaterEqual(pos.global_start, lower)
            self.assertGreaterEqual(pos.global_end, lower)
            self.assertLessEqual(pos.global_start, upper)
            self.assertLessEqual(pos.global_end, upper)

        # (index into positions, lower bound, upper bound) for a few characters
        expected_intervals = [
            (0, 0, 24),  # T
            (1, 24, 43),  # h
            (2, 45, 63),  # e
            # ...
            (-2, 1062, 1081),  # a
            (-1, 1084, 1099),  # s
        ]

        first_json_path = sorted(glob_all(jsons[0]))[0]
        with open(first_json_path) as json_file:
            first_pred: Predictions = Predictions.from_json(json_file.read())
            for prediction in first_pred.predictions:
                # The model is trained, so the sentence must be exact
                self.assertEqual(prediction.sentence, "The problem, simplified for our purposes, is set up as")
                for index, lower, upper in expected_intervals:
                    assert_pos_in_interval(prediction.positions[index], lower, upper)
Ejemplo n.º 14
0
 def __init__(self):
     # Default (command line style) training arguments for the uw3_50lines test data
     train_pattern = os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")
     model_dir = os.path.join(this_dir, "test_models")

     # input data, model and runtime
     self.files = glob_all([train_pattern])
     self.seed = 24
     self.backend = "tensorflow"
     self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5"
     self.line_height = 48
     self.pad = 16
     self.num_threads = 1
     self.display = 1
     self.batch_size = 1
     self.checkpoint_frequency = 1000
     self.max_iters = 1000
     self.stats_size = 100
     self.no_skip_invalid_gt = False
     self.no_progress_bars = True
     # output
     self.output_dir = model_dir
     self.output_model_prefix = "uw3_50lines"
     self.bidi_dir = None
     self.weights = None
     self.whitelist_files = []
     self.whitelist = []
     self.gradient_clipping_mode = "AUTO"
     self.gradient_clipping_const = 0
     # validation / early stopping (best models go to the same output directory)
     self.validation = None
     self.early_stopping_frequency = -1
     self.early_stopping_nbest = 10
     self.early_stopping_best_model_prefix = "uw3_50lines_best"
     self.early_stopping_best_model_output_dir = model_dir
     # augmentation, threading and text processing
     self.n_augmentations = 0
     self.fuzzy_ctc_library_path = ""
     self.num_inter_threads = 0
     self.num_intra_threads = 0
     self.text_regularization = ["extended"]
     self.text_normalization = "NFC"
Ejemplo n.º 15
0
    def __init__(self, settings: AlgorithmPredictorSettings):
        """Set up a MultiPredictor for the model's ``text.ckpt.json``.

        Predictions are combined with a confidence voter. A ``DictionaryCorrector``
        is created only if ``settings.params.useDictionaryCorrection`` is set.
        """
        super().__init__(settings)
        # NOTE: dictionary-constrained CTC decoding with lyrics normalization was
        # configured here in the past and is currently disabled.
        voter_params = VoterParams()
        voter_params.type = VoterParams.type.ConfidenceVoterDefaultCTC

        checkpoints = glob_all([settings.model.local_file('text.ckpt.json')])
        pipeline_params = DataPipelineParams(batch_size=1, mode=PipelineMode("prediction"))
        predictor_params = PredictorParams(silent=True, progress_bar=True, pipeline=pipeline_params)
        self.predictor = MultiPredictor.from_paths(
            checkpoints=checkpoints,
            voter_params=voter_params,
            predictor_params=predictor_params,
        )
        self.voter = voter_from_params(voter_params)
        self.dict_corrector = DictionaryCorrector() if settings.params.useDictionaryCorrection else None
Ejemplo n.º 16
0
    def __init__(self, n_folds, source_files, output_dir):
        """ Prepare cross fold training

        This class creates folds out of the given source files.
        The individual splits are then optionally written to the `output_dir` in a json format.

        The resolved source files are sorted, and the file with index i is assigned to
        fold i % n_folds (not randomly!), so the split is reproducible.

        Parameters
        ----------
        n_folds : int
            the number of folds to create (at least two)
        source_files : str or list of str
            the source file names or glob patterns, as accepted by ``glob_all``
        output_dir : str
            where to store the folds

        Raises
        ------
        Exception
            if no source files are found or fewer than two folds are requested
        """
        self.n_folds = n_folds
        # Sort so the fold assignment does not depend on the filesystem's glob order
        self.inputs = sorted(glob_all(source_files))
        self.output_dir = os.path.abspath(output_dir)

        if len(self.inputs) == 0:
            raise Exception("No files found at '{}'".format(source_files))

        if self.n_folds <= 1:
            raise Exception("At least two folds are required")

        # fill single fold files
        self.folds = [[] for _ in range(self.n_folds)]
        for i, source_file in enumerate(self.inputs):
            self.folds[i % n_folds].append(source_file)
Ejemplo n.º 17
0
def main():
    """Write an HTML page that shows each line image with its ground truth and prediction.

    For every image matched by ``--files``, the sibling ``.gt.txt`` and
    ``.pred.txt`` files (image path with all extensions replaced) are read and
    shown below the image. ``--open`` opens the page in the default browser.
    """
    from html import escape  # OCR text may contain '<' or '&'

    parser = argparse.ArgumentParser()
    parser.add_argument("--files", nargs="+", required=True,
                        help="The image files to predict with its gt and pred")
    parser.add_argument("--html_output", type=str, required=True,
                        help="Where to write the html file")
    parser.add_argument("--open", action="store_true",
                        help="Automatically open the file")

    args = parser.parse_args()
    img_files = sorted(glob_all(args.files))
    gt_files = [split_all_ext(f)[0] + ".gt.txt" for f in img_files]
    pred_files = [split_all_ext(f)[0] + ".pred.txt" for f in img_files]

    def read_text(path):
        # Close each file immediately instead of leaking the handle until GC
        with open(path, encoding="utf-8") as text_file:
            return text_file.read()

    # The page declares utf-8, so write it as utf-8 regardless of the platform default
    with open(args.html_output, 'w', encoding="utf-8") as out:
        out.write("""
                   <!DOCTYPE html>
                   <html lang="en">
                   <head>
                       <meta charset="utf-8"/>
                   </head>
                   <body>
                   <ul>""")

        for img, gt, pred in zip(img_files, gt_files, pred_files):
            # NOTE(review): every separator becomes a doubled backslash, which only
            # looks meaningful for Windows paths — confirm this is intended.
            src = img.replace('\\', '/').replace('/', '\\\\')
            # Escape all interpolated text so it is rendered literally, not parsed as HTML
            out.write("<li><p><img src=\"file://{}\"></p><p>{}</p><p>{}</p>\n".format(
                escape(src), escape(read_text(gt)), escape(read_text(pred))
            ))

        out.write("</ul></body></html>")

    if args.open:
        webbrowser.open(args.html_output)
Ejemplo n.º 18
0
    def __init__(self, n_folds, source_files, output_dir):
        """Split the given source files into ``n_folds`` folds for cross fold training.

        The splits may then optionally be written to ``output_dir`` in json format.
        The source files are resolved with ``glob_all`` and sorted; the file at index
        i of the sorted list goes to fold i % n_folds (deterministic, not random).

        Parameters
        ----------
        n_folds : int
            the number of folds to create (at least two)
        source_files : str or list of str
            the source file names or glob patterns, as accepted by ``glob_all``
        output_dir : str
            where to store the folds (stored as an absolute path)

        Raises
        ------
        Exception
            if no source files are found or fewer than two folds are requested
        """
        self.n_folds = n_folds
        self.inputs = sorted(glob_all(source_files))
        self.output_dir = os.path.abspath(output_dir)

        if not self.inputs:
            raise Exception("No files found at '{}'".format(source_files))

        if self.n_folds <= 1:
            raise Exception("At least two folds are required")

        # Round robin: fold k receives every n_folds-th file, starting at index k
        self.folds = [self.inputs[fold_index::n_folds] for fold_index in range(self.n_folds)]
Ejemplo n.º 19
0
 def test_prediction_files_with_different_extension(self):
     """Predictions written with a custom extension can be evaluated against the gt."""
     test_dir = os.path.join(this_dir, "data", "uw3_50lines", "test")
     custom_extension = '.ext-pred.txt'

     images = sorted(glob_all([os.path.join(test_dir, "*.png")]))
     run_predict(predict_args(data=FileDataParams(pred_extension=custom_extension, images=images)))

     gt_texts = sorted(glob_all([os.path.join(test_dir, "*.gt.txt")]))
     run_eval(eval_args(gt_data=FileDataParams(pred_extension=custom_extension, texts=gt_texts)))
Ejemplo n.º 20
0
    def prepare_for_mode(self, mode: PipelineMode):
        """Resolve the image and xml file lists and check that they pair up.

        Both lists are glob-expanded and sorted. Missing xml paths are derived from
        the image paths via ``self.gt_extension``. If no images are given, the image
        list is filled with ``None``. A warning is logged for every pair whose file
        names (without directory and extensions) differ.

        Raises:
            ValueError: if the numbers of image and xml files differ.
        """
        import os

        self.images = sorted(glob_all(self.images))
        self.xml_files = sorted(glob_all(self.xml_files))
        if not self.xml_files:
            self.xml_files = [split_all_ext(f)[0] + self.gt_extension for f in self.images]
        if not self.images:
            self.images = [None] * len(self.xml_files)

        if len(self.images) != len(self.xml_files):
            raise ValueError(f"Different number of image and xml files, {len(self.images)} != {len(self.xml_files)}")
        for img_path, xml_path in zip(self.images, self.xml_files):
            if img_path and xml_path:
                # Compare base names only: images and xml files are given independently
                # and may live in different directories; comparing full paths would warn
                # about every correctly matched pair in that case.
                img_bn = split_all_ext(os.path.basename(img_path))[0]
                xml_bn = split_all_ext(os.path.basename(xml_path))[0]
                if img_bn != xml_bn:
                    logger.warning(
                        f"Filenames are not matching, got base names \n  image: {img_bn}\n  xml:   {xml_bn}\n."
                    )
Ejemplo n.º 21
0
 def __init__(self, settings: AlgorithmPredictorSettings):
     """Load every ``omr_best*.ckpt.json`` model under the model path into a MultiPredictor.

     Predictions are combined with the default CTC confidence voter.
     ``self.height`` is taken from the first predictor's ``network_params.features``.
     """
     super().__init__(settings)
     checkpoint_pattern = settings.model.path + '/omr_best*.ckpt.json'
     self.predictor = MultiPredictor(
         glob_all([checkpoint_pattern]),
         ctc_decoder_params=settings.params.ctcDecoder.params,
     )
     first_predictor = self.predictor.predictors[0]
     self.height = first_predictor.network_params.features

     voter_params = VoterParams()
     voter_params.type = VoterParams.CONFIDENCE_VOTER_DEFAULT_CTC
     self.voter = voter_from_proto(voter_params)
Ejemplo n.º 22
0
 def __init__(self):
     # Default (command line style) training arguments for the uw3_50lines test data
     file_type = DataSetType.FILE
     train_pattern = os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")
     output_dir = None

     # training data
     self.dataset = file_type
     self.gt_extension = DataSetType.gt_extension(file_type)
     self.files = glob_all([train_pattern])
     # model, runtime and schedule
     self.seed = 24
     self.backend = "tensorflow"
     self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5"
     self.line_height = 48
     self.pad = 16
     self.num_threads = 1
     self.display = 1
     self.batch_size = 1
     self.checkpoint_frequency = 1000
     self.epochs = 1
     self.samples_per_epoch = 8
     self.stats_size = 100
     self.no_skip_invalid_gt = False
     self.no_progress_bars = True
     # output and weights
     self.output_dir = output_dir
     self.output_model_prefix = "uw3_50lines"
     self.bidi_dir = None
     self.weights = None
     self.ema_weights = False
     self.whitelist_files = []
     self.whitelist = []
     self.gradient_clipping_norm = 5
     # validation and early stopping (best models go to the same output directory)
     self.validation = None
     self.validation_dataset = file_type
     self.validation_extension = None
     self.validation_split_ratio = None
     self.early_stopping_frequency = -1
     self.early_stopping_nbest = 10
     self.early_stopping_at_accuracy = 0.99
     self.early_stopping_best_model_prefix = "uw3_50lines_best"
     self.early_stopping_best_model_output_dir = output_dir
     # augmentation, threading and text processing
     self.n_augmentations = 0
     self.num_inter_threads = 0
     self.num_intra_threads = 0
     self.text_regularization = ["extended"]
     self.text_normalization = "NFC"
     self.text_generator_params = None
     self.line_generator_params = None
     self.pagexml_text_index = 0
     self.text_files = None
     self.only_train_on_augmented = False
     self.data_preprocessing = [processor.name for processor in default_image_processors()]
     # data pipeline, codec and misc flags
     self.shuffle_buffer_size = 1000
     self.keep_loaded_codec = False
     self.train_data_on_the_fly = False
     self.validation_data_on_the_fly = False
     self.no_auto_compute_codec = False
     self.dataset_pad = 0
     self.debug = False
     self.train_verbose = True
     self.use_train_as_val = False
     self.ensemble = -1
     self.masking_mode = 1
Ejemplo n.º 23
0
def main():
    """Print statistics about a line dataset.

    Loads the images given by ``--files`` (ground truth in sibling ``.gt.txt``
    files) and prints:
    - the number of lines;
    - line widths, scaled to ``--line_height`` plus ``2 * --pad``;
    - characters per line and per-character counts;
    - the resulting codec size.
    """
    from collections import Counter

    parser = argparse.ArgumentParser()
    parser.add_argument("--files", nargs="+", required=True,
                        help="List of all image files with corresponding gt.txt files")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--line_height", type=int, default=48,
                        help="The line height")
    parser.add_argument("--pad", type=int, default=16,
                        help="Padding (left right) of the line")

    args = parser.parse_args()

    print("Resolving files")
    image_files = glob_all(args.files)
    gt_files = [split_all_ext(p)[0] + ".gt.txt" for p in image_files]

    ds = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=image_files, texts=gt_files, non_existing_as_empty=True)

    print("Loading {} files".format(len(image_files)))
    ds.load_samples(processes=1, progress_bar=True)
    images, texts = ds.train_samples(skip_empty=True)
    statistics = {
        "n_lines": len(images),
        "chars": [len(c) for c in texts],
        # width of each line after scaling it to line_height, plus padding;
        # missing or degenerate images are skipped
        "widths": [img.shape[1] / img.shape[0] * args.line_height + 2 * args.pad for img in images
                   if img is not None and img.shape[0] > 0 and img.shape[1] > 0],
        "total_line_width": 0,
        # Counter keeps first-seen order, like the previous manual dict count
        "char_counts": dict(Counter(c for _, text in zip(images, texts) for c in text)),
    }

    if not statistics["widths"] or not statistics["chars"]:
        # np.max/np.min raise ValueError on empty input: report instead of crashing
        print("No usable lines found in the dataset, cannot compute statistics")
        return

    statistics["av_line_width"] = np.average(statistics["widths"])
    statistics["max_line_width"] = np.max(statistics["widths"])
    statistics["min_line_width"] = np.min(statistics["widths"])
    statistics["total_line_width"] = np.sum(statistics["widths"])

    statistics["av_chars"] = np.average(statistics["chars"])
    statistics["max_chars"] = np.max(statistics["chars"])
    statistics["min_chars"] = np.min(statistics["chars"])
    statistics["total_chars"] = np.sum(statistics["chars"])

    statistics["av_px_per_char"] = statistics["av_line_width"] / statistics["av_chars"]
    statistics["codec_size"] = len(statistics["char_counts"])

    # the raw per-line lists are only needed for the aggregates above
    del statistics["chars"]
    del statistics["widths"]

    print(statistics)
Ejemplo n.º 24
0
def main():
    """Copy image/ground-truth pairs into ``--target_dir`` under sequential zero-padded names.

    For every resolved image ``<name>.<ext>`` the sibling ``<name>.gt.txt`` is expected; pairs where either
    file is missing are skipped. Images are copied (or converted to ``--convert_images`` via skimage), the
    ground truth is copied with extension ``--gt_ext`` and, with ``--index_files``, a file with extension
    ``--index_ext`` containing the running index is written.

    Raises:
        Exception: if no files match ``--files``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", nargs="+", type=str, required=True,
                        help="The image files to copy")
    parser.add_argument("--target_dir", type=str, required=True,
                        help="")
    parser.add_argument("--index_files", action="store_true")
    parser.add_argument("--convert_images", type=str,
                        help="Convert the image to a given type (by default use original format). E. g. jpg, png, tif, ...")
    parser.add_argument("--gt_ext", type=str, default=".gt.txt")
    parser.add_argument("--index_ext", type=str, default=".index")

    args = parser.parse_args()

    # normalize e.g. "png" to ".png" so it can be appended to the target file name
    if args.convert_images and not args.convert_images.startswith("."):
        args.convert_images = "." + args.convert_images

    args.target_dir = os.path.expanduser(args.target_dir)

    print("Resolving files")
    image_files = glob_all(args.files)
    gt_files = [split_all_ext(p)[0] + ".gt.txt" for p in image_files]

    if len(image_files) == 0:
        raise Exception("No files found")

    os.makedirs(args.target_dir, exist_ok=True)

    for i, (img, gt) in tqdm(enumerate(zip(image_files, gt_files)), total=len(gt_files), desc="Copying"):
        if not os.path.exists(img) or not os.path.exists(gt):
            # skip non existing examples
            continue

        # img with optional convert
        try:
            ext = split_all_ext(img)[1]
            target_ext = args.convert_images if args.convert_images else ext
            target_name = os.path.join(args.target_dir, "{:08}{}".format(i, target_ext))
            if ext == target_ext:
                shutil.copyfile(img, target_name)
            else:
                data = skimage_io.imread(img)
                skimage_io.imsave(target_name, data)
        except Exception as e:
            # Images that cannot be copied/converted are still skipped, but reported instead of being
            # swallowed silently; unlike a bare ``except`` this does not trap KeyboardInterrupt/SystemExit.
            tqdm.write("Skipping {}: {}".format(img, e))
            continue

        # gt txt
        target_name = os.path.join(args.target_dir, "{:08}{}".format(i, args.gt_ext))
        shutil.copyfile(gt, target_name)

        if args.index_files:
            target_name = os.path.join(args.target_dir, "{:08}{}".format(i, args.index_ext))
            with open(target_name, "w") as f:
                f.write(str(i))
Ejemplo n.º 25
0
def main():
    """Instantiate ``SavedCalamariModel`` for every checkpoint matched by ``--checkpoints``.

    The checkpoint paths are passed without their file extension; ``--dry_run`` is forwarded unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoints', nargs='+', type=str, required=True)
    parser.add_argument('--dry_run', action='store_true')
    args = parser.parse_args()

    for checkpoint_path in tqdm(glob_all(args.checkpoints)):
        base_path, _ext = os.path.splitext(checkpoint_path)
        SavedCalamariModel(base_path, dry_run=args.dry_run)
Ejemplo n.º 26
0
def setup_trainer_params(preload=True, debug=False):
    """Return post-initialized ensemble-scenario trainer params training on the uw3 50-lines train images.

    Args:
        preload: forwarded as ``preload`` of the training ``FileDataParams``.
        debug: stored as ``force_eager`` on the params.
    """
    train_images = os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")

    params = CalamariTestEnsembleScenario.default_trainer_params()
    params.force_eager = debug
    params.gen.train = FileDataParams(images=glob_all([train_images]), preload=preload)

    post_init(params)
    return params
Ejemplo n.º 27
0
 def __init__(self):
     """Hold a fixed prediction configuration for the bundled uw3 50-lines test images.

     Presumably used as a stand-in for parsed command line arguments in tests -- confirm at call sites.
     """
     test_images = os.path.join(this_dir, "data", "uw3_50lines", "test", "*.png")
     best_checkpoint = os.path.join(this_dir, "test_models", "uw3_50lines_best.ckpt")

     self.files = glob_all([test_images])
     self.checkpoint = [best_checkpoint]
     # single process, batch of one, verbose output without progress bars
     self.processes = 1
     self.batch_size = 1
     self.verbose = True
     self.voter = "confidence_voter_default_ctc"
     self.output_dir = None
     self.extended_prediction_data = None
     self.extended_prediction_data_format = "json"
     self.no_progress_bars = True
Ejemplo n.º 28
0
 def test_prediction_files(self):
     """Predict the uw3 test images, then evaluate against the gt files, once more writing an xlsx report."""
     test_dir = os.path.join(this_dir, "data", "uw3_50lines", "test")

     def sorted_matches(pattern):
         # sorted list of the test-directory files matching ``pattern``
         return sorted(glob_all([os.path.join(test_dir, pattern)]))

     run_predict(predict_args(data=FileDataParams(images=sorted_matches("*.png"))))
     run_eval(eval_args(gt_data=FileDataParams(texts=sorted_matches("*.gt.txt"))))

     xlsx_args = eval_args(gt_data=FileDataParams(texts=sorted_matches("*.gt.txt")))
     with tempfile.TemporaryDirectory() as tmp_dir:
         xlsx_args.xlsx_output = os.path.join(tmp_dir, 'output.xlsx')
         run_eval(xlsx_args)
Ejemplo n.º 29
0
    def __post_init__(self):
        """Resolve the character whitelist into ``self.resolved_include_chars``.

        A single ``include`` entry is split into its individual characters; several entries are taken
        as-is. Every character of every file matched by ``include_files`` is added as well.
        """
        # parse whitelist
        if len(self.include) == 1:
            include = set(self.include[0])
        else:
            include = set(self.include)

        for f in glob_all(self.include_files):
            # Whitelist files may contain arbitrary Unicode characters: decode them explicitly as UTF-8
            # instead of relying on the platform's default encoding.
            with open(f, encoding="utf-8") as txt:
                include.update(txt.read())

        self.resolved_include_chars = include
Ejemplo n.º 30
0
def main():
    """Normalize line images (range normalization, centering, final preparation) in a process pool.

    Every resolved image file is passed to ``Handler.handle_single``; ``--dry_run`` is forwarded to the
    handler (per its help text, files are then not overwritten).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files",
                        type=str,
                        nargs="+",
                        required=True,
                        help="Image files to preprocess")
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--pad_value",
                        type=int,
                        default=1,
                        help="Value used to fill the padding")
    parser.add_argument("--processes", type=int, default=1)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--invert", action="store_true")
    parser.add_argument("--transpose", action="store_true")
    parser.add_argument("--dry_run",
                        action="store_true",
                        help="Do not overwrite files, just run")

    args = parser.parse_args()

    params = DataPreprocessorParams()
    params.line_height = args.line_height
    params.pad = args.pad
    params.pad_value = args.pad_value
    params.no_invert = not args.invert
    # The preprocessing field is ``no_transpose``; the former misspelling ``no_transpos`` matched no field,
    # so --transpose never took effect. NOTE(review): confirm against DataPreprocessorParams.
    params.no_transpose = not args.transpose

    data_proc = MultiDataProcessor([
        DataRangeNormalizer(),
        CenterNormalizer(params),
        FinalPreparation(params, as_uint8=True),
    ])

    print("Resolving files")
    img_files = sorted(glob_all(args.files))

    handler = Handler(data_proc, args.dry_run)

    with multiprocessing.Pool(processes=args.processes,
                              maxtasksperchild=100) as pool:
        list(
            tqdm(pool.imap(handler.handle_single, img_files),
                 desc="Processing",
                 total=len(img_files)))
Ejemplo n.º 31
0
def main():
    """Call ``rename`` for every checkpoint matched by ``--checkpoints`` (extension stripped)."""
    parser = argparse.ArgumentParser(description=usage_str)
    parser.add_argument('--checkpoints', nargs='+', type=str, required=True)
    for option in ('--replace_from', '--replace_to', '--add_prefix'):
        parser.add_argument(option)
    parser.add_argument('--dry_run', action='store_true')
    args = parser.parse_args()

    checkpoint_paths = glob_all(args.checkpoints)
    for path in tqdm(checkpoint_paths):
        base_path = os.path.splitext(path)[0]
        rename(base_path, args.replace_from, args.replace_to, args.add_prefix, args.dry_run)
Ejemplo n.º 32
0
def create_train_dataset(cfg: CfgNode, dataset_args=None):
    """Resolve the training files configured in ``cfg.DATASET.TRAIN`` and create the training dataset.

    Images are globbed from ``PATH``. Ground truth is taken from ``TEXT_FILES`` (paired with the images by
    base name), or derived from the image names plus the gt extension, or -- if no gt extension is
    available -- set to ``None`` for every image.

    Args:
        cfg: configuration node providing ``DATASET.TRAIN`` and ``DATALOADER`` settings.
        dataset_args: optional extra arguments forwarded to ``create_dataset`` as ``args``.

    Returns:
        The dataset created by ``create_dataset`` in ``DataSetMode.TRAIN``.

    Raises:
        Exception: if paired image/gt base names differ or a sample occurs more than once.
    """
    # GT_EXTENSION == False means "use the default extension of the dataset type"
    gt_extension = cfg.DATASET.TRAIN.GT_EXTENSION if cfg.DATASET.TRAIN.GT_EXTENSION is not False else DataSetType.gt_extension(
        cfg.DATASET.TRAIN.TYPE)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(cfg.DATASET.TRAIN.PATH))
    if not cfg.DATASET.TRAIN.TEXT_FILES:
        if gt_extension:
            gt_txt_files = [
                split_all_ext(f)[0] + gt_extension for f in input_image_files
            ]
            duplicate_candidates = gt_txt_files
        else:
            gt_txt_files = [None] * len(input_image_files)
            # The None placeholders are all equal, so checking them would flag every dataset with more than
            # one image as containing duplicates; check the images themselves instead.
            duplicate_candidates = input_image_files
    else:
        gt_txt_files = sorted(glob_all(cfg.DATASET.TRAIN.TEXT_FILES))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))
        duplicate_candidates = gt_txt_files

    if len(set(duplicate_candidates)) != len(duplicate_candidates):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        cfg.DATASET.TRAIN.TYPE,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset
def main():
    """Apply unicode normalization, text regularization and stripping to text files.

    Each file is read as UTF-8, processed and -- unless ``--dry_run`` -- written back to the same path.
    ``--verbose`` prints the processed content.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files",
                        type=str,
                        nargs="+",
                        required=True,
                        help="Text files to apply text processing")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--dry_run",
                        action="store_true",
                        help="Do not overwrite files, just run")

    args = parser.parse_args()

    # Text pre processing (reading): normalizer -> regularizer -> strip
    preproc = TextProcessorParams()
    preproc.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(preproc.children.add(),
                                   default=args.text_normalization)
    default_text_regularizer_params(preproc.children.add(),
                                    groups=args.text_regularization)
    strip_processor_params = preproc.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    txt_proc = text_processor_from_proto(preproc, "pre")

    print("Resolving files")
    text_files = glob_all(args.files)

    for path in tqdm(text_files, desc="Processing", total=len(text_files)):
        # codecs.open keeps line endings untouched (no universal-newline translation)
        with codecs.open(path, "r", "utf-8") as f:
            content = f.read()

        content = txt_proc.apply(content)

        if args.verbose:
            print(content)

        if not args.dry_run:
            with codecs.open(path, "w", "utf-8") as f:
                f.write(content)
def main():
    """Print the average character probability of the best prediction over extended prediction files.

    Raises:
        Exception: if the given files yield no predictions.
    """
    parser = ArgumentParser()
    parser.add_argument("--pred", nargs="+", required=True,
                        help="Extended prediction files (.json extension)")

    args = parser.parse_args()

    print("Resolving files")
    pred_files = sorted(glob_all(args.pred))

    data_set = create_dataset(
        DataSetType.EXTENDED_PREDICTION,
        DataSetMode.EVAL,
        texts=pred_files,
    )

    data_set.load_samples(progress_bar=True)
    confidences = [s['best_prediction'].avg_char_probability for s in data_set.samples()]
    if not confidences:
        # np.mean([]) would just print "nan" together with a RuntimeWarning
        raise Exception("No predictions found in {}".format(args.pred))
    print('Average confidence: {:.2%}'.format(np.mean(confidences)))
Ejemplo n.º 35
0
def main():
    """Normalize line images (range normalization, centering, final preparation) in a process pool.

    Every resolved image file is passed to ``Handler.handle_single``; ``--dry_run`` is forwarded to the
    handler (per its help text, files are then not overwritten).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", type=str, nargs="+", required=True,
                        help="Image files to preprocess")
    parser.add_argument("--line_height", type=int, default=48,
                        help="The line height")
    parser.add_argument("--pad", type=int, default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--pad_value", type=int, default=1,
                        help="Value used to fill the padding")
    parser.add_argument("--processes", type=int, default=1)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--invert", action="store_true")
    parser.add_argument("--transpose", action="store_true")
    parser.add_argument("--dry_run", action="store_true",
                        help="Do not overwrite files, just run")

    args = parser.parse_args()

    params = DataPreprocessorParams()
    params.line_height = args.line_height
    params.pad = args.pad
    params.pad_value = args.pad_value
    params.no_invert = not args.invert
    # The preprocessing field is ``no_transpose``; the former misspelling ``no_transpos`` matched no field,
    # so --transpose never took effect. NOTE(review): confirm against DataPreprocessorParams.
    params.no_transpose = not args.transpose

    data_proc = MultiDataProcessor([
        DataRangeNormalizer(),
        CenterNormalizer(params),
        FinalPreparation(params, as_uint8=True),
    ])

    print("Resolving files")
    img_files = sorted(glob_all(args.files))

    handler = Handler(data_proc, args.dry_run)

    with multiprocessing.Pool(processes=args.processes, maxtasksperchild=100) as pool:
        list(tqdm(pool.imap(handler.handle_single, img_files), desc="Processing", total=len(img_files)))
Ejemplo n.º 36
0
def main():
    """Convert serialized ``Predictions`` protobuf files to JSON files next to them.

    Logits are cleared before conversion unless ``--logits`` is given. The output path is the input path
    with all extensions replaced by ``.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", type=str, default=[], nargs="+", required=True,
                        help="Protobuf files to convert")
    parser.add_argument("--logits", action="store_true",
                        help="Do write logits")
    args = parser.parse_args()
    strip_logits = not args.logits

    for path in tqdm(glob_all(args.files), desc="Converting"):
        predictions = Predictions()
        with open(path, 'rb') as src:
            serialized = src.read()
        predictions.ParseFromString(serialized)

        if strip_logits:
            for prediction in predictions.predictions:
                logits = prediction.logits
                logits.rows = 0
                logits.cols = 0
                logits.data[:] = []

        json_path = split_all_ext(path)[0] + ".json"
        with open(json_path, 'w') as dst:
            dst.write(MessageToJson(predictions, including_default_value_fields=True))
Ejemplo n.º 37
0
def main():
    """Split image/ground-truth pairs into an evaluation and a training directory.

    ``--n_eval``/``--n_train`` are absolute counts when >= 1, fractions of all files when in [0, 1), and a
    negative value means "all remaining files" (at most one of the two may be negative). The first
    ``n_eval`` files of the sorted list go to the evaluation split, the following ``n_train`` files to the
    training split. Pairs with a missing image or gt file are reported and skipped.

    Raises:
        Exception: if no files are found, both counts are negative, or the counts exceed the file count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", nargs="+", required=True,
                        help="All img files, an appropriate .gt.txt must exist")
    parser.add_argument("--n_eval", type=float, required=True,
                        help="The (relative or absolute) count of evaluation files (or -1 to use the remaining)")
    parser.add_argument("--n_train", type=float, required=True,
                        help="The (relative or absolute) count of training files (or -1 to use the remaining)")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Where to write the splits")
    parser.add_argument("--eval_sub_dir", type=str, default="eval")
    parser.add_argument("--train_sub_dir", type=str, default="train")

    args = parser.parse_args()

    img_files = sorted(glob_all(args.files))
    if len(img_files) == 0:
        raise Exception("No files were found")

    gt_txt_files = [split_all_ext(p)[0] + ".gt.txt" for p in img_files]

    def resolve_count(value):
        # Negative values stay as the "use the remaining files" marker and are resolved below.
        if value < 0:
            return value
        # Fractions refer to the total number of files. Scale before truncating: the former
        # ``int(value) * n`` truncated the fraction to 0, so every relative split was empty.
        if value < 1:
            return int(value * len(img_files))
        return int(value)

    args.n_eval = resolve_count(args.n_eval)
    args.n_train = resolve_count(args.n_train)

    if args.n_eval < 0 and args.n_train < 0:
        raise Exception("Only one of n_eval or n_train may be < 0")

    if args.n_eval < 0:
        args.n_eval = len(img_files) - args.n_train
    elif args.n_train < 0:
        args.n_train = len(img_files) - args.n_eval

    if args.n_eval + args.n_train > len(img_files):
        raise Exception("Got {} eval and {} train files = {} in total, but only {} files are in the dataset".format(
            args.n_eval, args.n_train, args.n_eval + args.n_train, len(img_files)
        ))

    def copy_files(imgs, txts, out_dir):
        assert(len(imgs) == len(txts))

        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        for img, txt in tqdm(zip(imgs, txts), total=len(imgs), desc="Writing to {}".format(out_dir)):
            if not os.path.exists(img):
                print("Image file at {} not found".format(img))
                continue

            if not os.path.exists(txt):
                print("Ground truth file at {} not found".format(txt))
                continue

            shutil.copyfile(img, os.path.join(out_dir, os.path.basename(img)))
            shutil.copyfile(txt, os.path.join(out_dir, os.path.basename(txt)))

    eval_end = args.n_eval
    # Respect --n_train: previously every file after the eval split was copied to the train split.
    train_end = args.n_eval + args.n_train
    copy_files(img_files[:eval_end], gt_txt_files[:eval_end], os.path.join(args.output_dir, args.eval_sub_dir))
    copy_files(img_files[eval_end:train_end], gt_txt_files[eval_end:train_end],
               os.path.join(args.output_dir, args.train_sub_dir))
Ejemplo n.º 38
0
def main():
    """Resume training from ``--checkpoint`` on the given training and optional validation files.

    The checkpoint parameters are read from ``<checkpoint>.json`` and the same (suffix-less) path is passed
    as ``weights`` to the ``Trainer``; a trailing ``.json`` on ``--checkpoint`` is accepted.

    Raises:
        Exception: if paired image/gt base names differ or a sample occurs more than once.
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth files with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow the user to pass the checkpoint's json file; the suffix is appended again below
    if args.checkpoint.endswith(".json"):
        args.checkpoint = args.checkpoint[:-5]

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        # sorted like the training images (and like the validation text files below) so that pairing by
        # base name lines up
        validation_image_files = sorted(glob_all(args.validation))
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    # the params file is no longer needed, so train outside the open-file context
    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)
Ejemplo n.º 39
0
def main():
    """Evaluate every checkpoint and several voters over the ensemble on a labeled image set.

    Ground truth is read from ``<image base name>.gt.txt`` next to each image. ``--verbose`` prints the
    full evaluation, ``--dump`` pickles it together with the gt paths and gt texts.

    Raises:
        Exception: if the dataset is empty or a prediction count does not match the dataset size.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+", default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
                             "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # load files; the gt paths are derived from the very same sorted image list so both stay aligned
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # the argument is --eval_imgs; the former ``args.files`` raised an AttributeError here
        raise Exception("Empty dataset provided. Check your eval_imgs argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        r = evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)

        return r

    # keys: model index as string for the single models, voter name for the voted results
    full_evaluation = {}
    for key, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[key] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
Ejemplo n.º 40
0
def run(args):
    """Configure and start a training run from parsed command line ``args``.

    If ``args.files`` is a single path ending in ``json``, its keys first override the matching attributes
    of ``args``. The training (and optional validation) image/gt files are resolved, the checkpoint
    parameters (preprocessing, text processing, network, early stopping) are filled from ``args``, and a
    ``Trainer`` with ``SimpleDataAugmenter`` is run.

    Raises:
        Exception: if paired image/gt base names differ or a sample occurs more than once.
    """

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # parse whitelist; work on a copy so that extending it below does not mutate args.whitelist in place
    whitelist = list(args.whitelist)
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        # whitelists may contain arbitrary Unicode characters: don't rely on the platform default encoding
        with open(f, encoding="utf-8") as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        # sorted like the training images (and like the validation text files below) so that pairing by
        # base name lines up
        validation_image_files = sorted(glob_all(args.validation))
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    # a negative checkpoint frequency falls back to the early stopping frequency
    params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        # post processing always uses automatic direction detection
        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )
Ejemplo n.º 41
0
def main():
    """Command-line entry point: evaluate predicted text lines against their ground truth.

    Resolves the ground-truth files (``--gt``) and the prediction files (``--pred``, or
    ``<gt path without extensions><--pred_ext>`` next to each gt file), builds one EVAL
    dataset for each, runs the ``Evaluator`` and prints the mean normalized label error
    rate, the most common confusions and the worst recognized lines. Optionally writes an
    xlsx report (``--xlsx_output``).

    ``--non_existing_file_handling_mode`` (case-insensitive):
      * ``skip``  - missing prediction files are dropped together with the gt file at the
                    same position before evaluation;
      * ``error`` - the datasets are created with ``non_existing_as_empty=False`` (per the
                    option help this raises on a missing file);
      * anything else - missing files are treated as empty.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with its error")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as would it be empty (fully checking for errors). "
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args; type=int so command-line values match the integer defaults
    parser.add_argument("--pagexml_gt_text_index", type=int, default=0)
    parser.add_argument("--pagexml_pred_text_index", type=int, default=1)

    args = parser.parse_args()
    missing_file_mode = args.non_existing_file_handling_mode.lower()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        # predictions live next to the gt files and therefore share the gt dataset type
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    if missing_file_mode == "skip":
        # drop every missing prediction together with the gt file at the same position
        missing = {idx for idx, pred in enumerate(pred_files) if not os.path.exists(pred)}
        pred_files = [pred for idx, pred in enumerate(pred_files) if idx not in missing]
        gt_files = [gt for idx, gt in enumerate(gt_files) if idx not in missing]

    text_preproc = None
    if args.checkpoint:
        # the checkpoint parameters are stored as a JSON file next to the checkpoint
        with open(args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(checkpoint_params.model.text_preprocessor)

    # only the "error" mode must fail on missing files; "skip" already removed them above
    non_existing_as_empty = missing_file_mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                       "gts": gt_data_set.text_samples(),
                       "preds": pred_data_set.text_samples()
                   }])
Ejemplo n.º 42
0
def main():
    """Command-line entry point for cross-fold training experiments.

    For every value of ``--n_lines`` a shallow copy of the arguments is handed to
    ``run_for_single_line`` (run in a thread pool), which trains and evaluates the fold
    models. Afterwards this prints a CSV summary (one row per line setting: the label error
    rate of each fold, their mean and std, and the three voters), optionally the most common
    confusions (``--n_confusions``), and optionally writes an xlsx report (``--xlsx_output``).

    Raises:
        Exception: if the number of ``--weights`` is neither 0, 1 nor ``--n_folds``, or if
            ``--single_fold`` contains repeated or out-of-range fold ids.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str, required=True,
                        help="The base directory where to store all working files")
    parser.add_argument("--eval_files", type=str, nargs="+", required=True,
                        help="All files that shall be used for evaluation")
    parser.add_argument("--train_files", type=str, nargs="+", required=True,
                        help="All files that shall be used for (cross-fold) training")
    parser.add_argument("--n_lines", type=int, default=[-1], nargs="+",
                        help="Optional argument to specify the number of lines (images) used for training. "
                             "On default, all available lines will be used.")
    parser.add_argument("--run", type=str, default=None,
                        help="An optional command that will receive the train calls. Useful e.g. when using a resource "
                             "manager such as slurm.")

    parser.add_argument("--n_folds", type=int, default=5,
                        help="The number of fold, that is the number of models to train")
    parser.add_argument("--max_parallel_models", type=int, default=-1,
                        help="Number of models to train in parallel per fold. Defaults to all.")
    parser.add_argument("--weights", type=str, nargs="+", default=[],
                        help="Load network weights from the given file. If more than one file is provided the number "
                             "models must match the number of folds. Each fold is then initialized with the weights "
                             "of each model, respectively")
    parser.add_argument("--single_fold", type=int, nargs="+", default=[],
                        help="Only train a single (list of single) specific fold(s).")
    parser.add_argument("--skip_train", action="store_true",
                        help="Skip the cross fold training")
    parser.add_argument("--skip_eval", action="store_true",
                        help="Skip the cross fold evaluation")
    parser.add_argument("--verbose", action="store_true",
                        help="Verbose output")
    parser.add_argument("--n_confusions", type=int, default=0,
                        help="Only print n most common confusions. Defaults to 0, use -1 for all.")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")

    setup_train_args(parser, omit=["files", "validation", "weights",
                                   "early_stopping_best_model_output_dir", "early_stopping_best_model_prefix",
                                   "output_dir"])

    args = parser.parse_args()

    args.base_dir = os.path.abspath(os.path.expanduser(args.base_dir))

    np.random.seed(args.seed)
    random.seed(args.seed)

    # argument checks
    args.weights = glob_all(args.weights)
    if len(args.weights) > 1 and len(args.weights) != args.n_folds:
        raise Exception("Either no, one or n_folds (={}) models are required for pretraining but got {}.".format(
            args.n_folds, len(args.weights)
        ))

    if args.single_fold:
        if len(set(args.single_fold)) != len(args.single_fold):
            raise Exception("Repeated fold id's found.")
        for fold_id in args.single_fold:
            if not 0 <= fold_id < args.n_folds:
                # valid fold ids are 0 .. n_folds - 1 (upper bound exclusive)
                raise Exception("Invalid fold id found: 0 <= id < {}, but id == {}".format(args.n_folds, fold_id))

        actual_folds = args.single_fold
    else:
        actual_folds = list(range(args.n_folds))

    # run for all lines: every line setting gets its own shallow copy of the arguments
    single_args = []
    for n_lines in args.n_lines:
        s_args = copy.copy(args)
        s_args.n_lines = n_lines
        single_args.append(s_args)

    predictions = parallel_map(run_for_single_line, single_args, progress_bar=False, processes=len(single_args),
                               use_thread_pool=True)

    # NOTE(review): the results of the individual models are read by their position
    # ("0", "1", ...) in the checkpoint list passed to the eval script, while the fold ids in
    # actual_folds are only used as labels. All three outputs below use this same convention
    # (they previously disagreed when --single_fold was given) - confirm against experiment_eval.py.
    fold_keys = [str(idx) for idx in range(len(actual_folds))]
    voters = ['sequence_voter', 'confidence_voter_default_ctc', 'confidence_voter_fuzzy_ctc']

    # output predictions as csv; one fold column per evaluated fold, labelled by its fold id
    header = "lines," + ",".join(str(fold) for fold in actual_folds) \
             + ",avg,std,seq. vot., def. conf. vot., fuz. conf. vot."

    print(header)

    for prediction_map, n_lines in zip(predictions, args.n_lines):
        prediction = prediction_map["full"]
        folds_lers = [prediction[key]["eval"]['avg_ler'] for key in fold_keys]
        voter_lers = [prediction[voter]["eval"]['avg_ler'] for voter in voters]
        row = [n_lines] + folds_lers + [np.mean(folds_lers), np.std(folds_lers)] + voter_lers
        print(",".join("{}".format(value) for value in row))

    if args.n_confusions != 0:
        for prediction_map, n_lines in zip(predictions, args.n_lines):
            prediction = prediction_map["full"]
            print("")
            print("CONFUSIONS (lines = {})".format(n_lines))
            print("==========")
            print()

            for key, fold in zip(fold_keys, actual_folds):
                print("FOLD {}".format(fold))
                print_confusions(prediction[key]['eval'], args.n_confusions)

            for voter in voters:
                print("VOTER {}".format(voter))
                print_confusions(prediction[voter]['eval'], args.n_confusions)

    if args.xlsx_output:
        data_list = []
        for prediction_map, n_lines in zip(predictions, args.n_lines):
            prediction = prediction_map["full"]
            for key, fold in zip(fold_keys, actual_folds):
                pred = prediction[key]
                data_list.append({
                    "prefix": "L{} - Fold{}".format(n_lines, fold),
                    "results": pred['eval'],
                    "gt_files": prediction_map['gt_txts'],
                    "gts": prediction_map['gt'],
                    "preds": pred['data']
                })

            # presumably the fuzzy voter is omitted because voter[:3] ("con") would repeat the
            # prefix of the default confidence voter
            for voter in ['sequence_voter', 'confidence_voter_default_ctc']:
                pred = prediction[voter]
                data_list.append({
                    "prefix": "L{} - {}".format(n_lines, voter[:3]),
                    "results": pred['eval'],
                    "gt_files": prediction_map['gt_txts'],
                    "gts": prediction_map['gt'],
                    "preds": pred['data']
                })

        write_xlsx(args.xlsx_output, data_list)
Ejemplo n.º 43
0
def run_for_single_line(args):
    """Train and evaluate the cross-fold models for a single ``n_lines`` setting.

    ``args`` is modified in place:
      * ``base_dir`` becomes ``<base_dir>/<"all" | n_lines>/<network>/<pretraining>``, where
        ``<pretraining>`` is ``"scratch"`` or the comma-joined base names of ``args.weights``;
      * ``best_models_dir``, ``temporary_dir``, ``keep_temporary_files``, ``files`` and
        ``best_model_label`` are set for ``cross_fold_train.main``.

    The directories ``base_dir``, ``tmp``, ``models`` and ``predictions`` are created if needed.
    Unless ``args.skip_train`` is set, the cross-fold training is run. Unless
    ``args.skip_eval`` is set, ``experiment_eval.py`` is run on the fold models and dumps its
    results to ``<base_dir>/tmp/prediction.pkl``. That pickle is loaded and returned in every
    case, so it must already exist from an earlier run when evaluation is skipped.

    Raises:
        Exception: if the expected fold checkpoints are missing when evaluating.
    """
    # lines/network/pretraining as base dir
    lines_dir = "all" if args.n_lines < 0 else str(args.n_lines)
    if args.weights:
        pretrain_prefix = ",".join(split_all_ext(os.path.basename(path))[0] for path in args.weights)
    else:
        pretrain_prefix = "scratch"
    args.base_dir = os.path.join(args.base_dir, lines_dir, args.network, pretrain_prefix)

    tmp_dir = os.path.join(args.base_dir, "tmp")
    best_models_dir = os.path.join(args.base_dir, "models")
    prediction_dir = os.path.join(args.base_dir, "predictions")
    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    for directory in (args.base_dir, tmp_dir, best_models_dir, prediction_dir):
        os.makedirs(directory, exist_ok=True)

    # select number of files
    files = args.train_files
    if args.n_lines > 0:
        # NOTE(review): this uses the global `random` state, which main() shares across the
        # threads it starts, so the selection may not be reproducible from the seed alone.
        files = random.sample(glob_all(args.train_files), args.n_lines)

    # run the cross-fold-training
    args.best_models_dir = best_models_dir
    args.temporary_dir = tmp_dir
    args.keep_temporary_files = False
    args.files = files
    args.best_model_label = "{id}"
    if not args.skip_train:
        cross_fold_train.main(args)

    dump_file = os.path.join(tmp_dir, "prediction.pkl")

    # run the prediction
    if not args.skip_eval:
        # locate the eval script (must be in the same dir as "this")
        predict_script_path = os.path.join(this_absdir, "experiment_eval.py")

        if args.single_fold:
            models = [os.path.join(best_models_dir, "{}.ckpt.json".format(sf)) for sf in args.single_fold]
            for m in models:
                if not os.path.exists(m):
                    raise Exception("Expected model at '{}', but file does not exist".format(m))
        else:
            models = [os.path.join(best_models_dir, d) for d in sorted(os.listdir(best_models_dir)) if d.endswith("json")]
            if len(models) != args.n_folds:
                raise Exception("Expected {} models, one for each fold respectively, but {} models were found".format(
                    args.n_folds, len(models)
                ))

        command = ["python3", "-u",
                   predict_script_path,
                   "-j", str(args.num_threads),
                   "--batch_size", str(args.batch_size),
                   "--dump", dump_file,
                   "--eval_imgs"] + args.eval_files
        if args.verbose:
            command.append("--verbose")
        command += ["--checkpoint"] + models

        for line in run(prefix_run_command(command, args.run, {"threads": args.num_threads}), verbose=args.verbose):
            # Print the output of the thread
            if args.verbose:
                print(line)

    # the dump is written by our own eval script above (or a previous run); it is not external input
    import pickle
    with open(dump_file, 'rb') as f:
        prediction = pickle.load(f)

    return prediction