Example #1
    def prepare_for_mode(self, mode: PipelineMode):
        logger.info("Resolving input files")
        input_image_files = sorted(glob_all(self.images))

        if not self.texts:
            gt_txt_files = [split_all_ext(f)[0] + self.gt_extension for f in input_image_files]
        else:
            gt_txt_files = sorted(glob_all(self.texts))
            if mode in INPUT_PROCESSOR:
                input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
                for img, gt in zip(input_image_files, gt_txt_files):
                    if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                        raise Exception(f"Expected identical basenames of file: {img} and {gt}")
            else:
                input_image_files = None

        if mode in {PipelineMode.TRAINING, PipelineMode.EVALUATION}:
            if len(set(gt_txt_files)) != len(gt_txt_files):
                logger.warning(
                    "Some ground truth text files occur more than once in the data set "
                    "(ignore this warning, if this was intended)."
                )
            if len(set(input_image_files)) != len(input_image_files):
                logger.warning(
                    "Some images occur more than once in the data set. " "This warning should usually not be ignored."
                )

        self.images = input_image_files
        self.texts = gt_txt_files
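
A note on the convention these examples share: every line image <name>.<ext> is paired with a ground-truth transcription <name>.gt.txt (or another gt_extension) by stripping all extensions and comparing basenames. The sketch below reproduces that pairing with the standard library only; strip_all_ext is a simplified stand-in for Calamari's split_all_ext, and the glob pattern is illustrative.

import glob
import os


def strip_all_ext(path: str) -> str:
    # Simplified stand-in for split_all_ext()[0]: drop every trailing
    # extension, so "line01.bin.png" -> "line01".
    return os.path.basename(path).split(".", 1)[0]


def pair_images_with_gt(image_glob: str, gt_extension: str = ".gt.txt"):
    pairs = []
    for img in sorted(glob.glob(image_glob)):
        gt = os.path.join(os.path.dirname(img), strip_all_ext(img) + gt_extension)
        if strip_all_ext(img) != strip_all_ext(gt):
            raise ValueError(f"Expected identical basenames: {img} and {gt}")
        pairs.append((img, gt))
    return pairs


# Illustrative call; adjust the pattern to your data layout.
# pairs = pair_images_with_gt("lines/*.png")
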
Example #2
def create_train_dataset(args, dataset_args=None):
    gt_extension = args.gt_extension if args.gt_extension is not None else DataSetType.gt_extension(args.dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        if gt_extension:
            gt_txt_files = [split_all_ext(f)[0] + gt_extension for f in input_image_files]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset
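
The duplicate check above only reports that some ground truth file occurs more than once. When debugging a data set it can help to see which files are affected; a purely illustrative standard-library variant:

from collections import Counter


def report_duplicates(paths):
    # Return every path that appears more than once, with its count.
    duplicates = {p: n for p, n in Counter(paths).items() if n > 1}
    for path, count in sorted(duplicates.items()):
        print(f"{path} occurs {count} times")
    return duplicates


# e.g. report_duplicates(gt_txt_files)
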
Example #3
def create_dataset(
    type: DataSetType,
    mode: DataSetMode,
    images: List[str] = None,
    texts: List[str] = None,
    skip_invalid=False,
    remove_invalid=True,
    non_existing_as_empty=False,
    args: dict = None,
):
    if images is None:
        images = []

    if texts is None:
        texts = []

    if args is None:
        args = dict()

    if DataSetType.files(type):
        if images:
            images.sort()

        if texts:
            texts.sort()

        if images and texts and len(images) > 0 and len(texts) > 0:
            images, texts = keep_files_with_same_file_name(images, texts)

    if type == DataSetType.RAW:
        return RawDataSet(mode, images, texts)

    elif type == DataSetType.FILE:
        return FileDataSet(mode,
                           images,
                           texts,
                           skip_invalid=skip_invalid,
                           remove_invalid=remove_invalid,
                           non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.ABBYY:
        return AbbyyDataSet(mode,
                            images,
                            texts,
                            skip_invalid=skip_invalid,
                            remove_invalid=remove_invalid,
                            non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.PAGEXML:
        return PageXMLDataset(mode,
                              images,
                              texts,
                              skip_invalid=skip_invalid,
                              remove_invalid=remove_invalid,
                              non_existing_as_empty=non_existing_as_empty,
                              args=args)
    elif type == DataSetType.EXTENDED_PREDICTION:
        from .extended_prediction_dataset import ExtendedPredictionDataSet
        return ExtendedPredictionDataSet(texts=texts)
    else:
        raise Exception("Unsupported dataset type {}".format(type))
Example #4
    def prepare_for_mode(self, mode: PipelineMode) -> 'PipelineParams':
        from calamari_ocr.ocr.dataset.datareader.factory import DataReaderFactory
        assert (self.type is not None)
        params_out = deepcopy(self)
        # Training dataset
        logger.info("Resolving input files")
        if isinstance(self.type, str):
            try:
                self.type = DataSetType.from_string(self.type)
            except ValueError:
                # Not a valid type, must be custom
                if self.type not in DataReaderFactory.CUSTOM_READERS:
                    raise KeyError(
                        f"DataSetType {self.type} is neither a standard DataSetType or preset as custom "
                        f"reader ({list(DataReaderFactory.CUSTOM_READERS.keys())})"
                    )
        if not isinstance(self.type, str) and self.type not in {
                DataSetType.RAW, DataSetType.GENERATED_LINE
        }:
            input_image_files = sorted(glob_all(
                self.files)) if self.files else None

            if not self.text_files:
                if self.gt_extension:
                    gt_txt_files = [
                        split_all_ext(f)[0] + self.gt_extension
                        for f in input_image_files
                    ]
                else:
                    gt_txt_files = None
            else:
                gt_txt_files = sorted(glob_all(self.text_files))
                if mode in INPUT_PROCESSOR:
                    input_image_files, gt_txt_files = keep_files_with_same_file_name(
                        input_image_files, gt_txt_files)
                    for img, gt in zip(input_image_files, gt_txt_files):
                        if split_all_ext(
                                os.path.basename(img))[0] != split_all_ext(
                                    os.path.basename(gt))[0]:
                            raise Exception(
                                "Expected identical basenames of file: {} and {}"
                                .format(img, gt))
                else:
                    input_image_files = None

            if mode in {PipelineMode.Training, PipelineMode.Evaluation}:
                if len(set(gt_txt_files)) != len(gt_txt_files):
                    logger.warning(
                        "Some ground truth text files occur more than once in the data set "
                        "(ignore this warning, if this was intended).")
                if len(set(input_image_files)) != len(input_image_files):
                    logger.warning(
                        "Some images occur more than once in the data set. "
                        "This warning should usually not be ignored.")

            params_out.files = input_image_files
            params_out.text_files = gt_txt_files
        return params_out
Example #5
def data_reader_from_params(mode: PipelineMode,
                            params: PipelineParams) -> DataReader:
    assert (params.type is not None)
    from calamari_ocr.ocr.dataset.dataset_factory import create_data_reader
    # Training dataset
    logger.info("Resolving input files")
    if params.type not in {DataSetType.RAW, DataSetType.GENERATED_LINE}:
        input_image_files = sorted(glob_all(
            params.files)) if params.files else None

        if not params.text_files:
            if params.gt_extension:
                gt_txt_files = [
                    split_all_ext(f)[0] + params.gt_extension
                    for f in input_image_files
                ]
            else:
                gt_txt_files = None
        else:
            gt_txt_files = sorted(glob_all(params.text_files))
            if mode in INPUT_PROCESSOR:
                input_image_files, gt_txt_files = keep_files_with_same_file_name(
                    input_image_files, gt_txt_files)
                for img, gt in zip(input_image_files, gt_txt_files):
                    if split_all_ext(
                            os.path.basename(img))[0] != split_all_ext(
                                os.path.basename(gt))[0]:
                        raise Exception(
                            "Expected identical basenames of file: {} and {}".
                            format(img, gt))
            else:
                input_image_files = None

        if mode in {PipelineMode.Training, PipelineMode.Evaluation}:
            if len(set(gt_txt_files)) != len(gt_txt_files):
                logger.warning(
                    "Some ground truth text files occur more than once in the data set "
                    "(ignore this warning, if this was intended).")
            if len(set(input_image_files)) != len(input_image_files):
                logger.warning(
                    "Some images occur more than once in the data set. "
                    "This warning should usually not be ignored.")
    else:
        input_image_files = params.files
        gt_txt_files = params.text_files

    dataset = create_data_reader(
        params.type,
        mode,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=params.skip_invalid,
        args=params.data_reader_args
        if params.data_reader_args else FileDataReaderArgs(),
    )
    logger.info(f"Found {len(dataset)} files in the dataset")
    return dataset
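
A hedged sketch of the parameters data_reader_from_params reads, using types.SimpleNamespace as a stand-in for PipelineParams (the real class and its import path are not shown here; the field names are taken directly from the function body, the values are placeholders):

from types import SimpleNamespace

params = SimpleNamespace(
    type=None,              # e.g. DataSetType.FILE (assumed); must not be None for the assert
    files=["train/*.png"],  # image globs, expanded by glob_all
    text_files=None,        # None: ground truth is derived via gt_extension
    gt_extension=".gt.txt",
    skip_invalid=True,
    data_reader_args=None,  # falls back to FileDataReaderArgs()
)
# reader = data_reader_from_params(PipelineMode.Training, params)
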
Example #6
def create_test_dataset(
    cfg: CfgNode,
    dataset_args=None
) -> Union[List[Union[RawDataSet, FileDataSet, AbbyyDataSet, PageXMLDataset,
                      Hdf5DataSet, ExtendedPredictionDataSet,
                      GeneratedLineDataset]], None]:
    if cfg.DATASET.VALID.TEXT_FILES:
        assert len(cfg.DATASET.VALID.PATH) == len(cfg.DATASET.VALID.TEXT_FILES)

    if cfg.DATASET.VALID.PATH:
        validation_dataset_list = []
        print("Resolving validation files")
        for i, valid_path in enumerate(cfg.DATASET.VALID.PATH):
            validation_image_files = glob_all(valid_path)
            dataregistry.register(
                i, os.path.basename(os.path.dirname(valid_path)),
                len(validation_image_files))

            if not cfg.DATASET.VALID.TEXT_FILES:
                val_txt_files = [
                    split_all_ext(f)[0] + cfg.DATASET.VALID.GT_EXTENSION
                    for f in validation_image_files
                ]
            else:
                val_txt_files = sorted(
                    glob_all(cfg.DATASET.VALID.TEXT_FILES[i]))
                validation_image_files, val_txt_files = keep_files_with_same_file_name(
                    validation_image_files, val_txt_files)
                for img, gt in zip(validation_image_files, val_txt_files):
                    if split_all_ext(
                            os.path.basename(img))[0] != split_all_ext(
                                os.path.basename(gt))[0]:
                        raise Exception(
                            "Expected identical basenames of validation file: {} and {}"
                            .format(img, gt))

            if len(set(val_txt_files)) != len(val_txt_files):
                raise Exception(
                    "Some validation ground truth files occur more than once in the data set."
                )

            validation_dataset = create_dataset(
                cfg.DATASET.VALID.TYPE,
                DataSetMode.TRAIN,
                images=validation_image_files,
                texts=val_txt_files,
                skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT,
                args=dataset_args,
            )
            print("Found {} files in the validation dataset".format(
                len(validation_dataset)))
            validation_dataset_list.append(validation_dataset)
    else:
        validation_dataset_list = None

    return validation_dataset_list
Example #7
def create_dataset(
        type: DataSetType,
        mode: DataSetMode,
        images=list(),
        texts=list(),
        skip_invalid=False,
        remove_invalid=True,
        non_existing_as_empty=False,
        args=dict(),
):
    if DataSetType.files(type):
        if images:
            images.sort()

        if texts:
            texts.sort()

        if images and texts and len(images) > 0 and len(texts) > 0:
            images, texts = keep_files_with_same_file_name(images, texts)

    if type == DataSetType.RAW:
        return RawDataSet(mode, images, texts)

    elif type == DataSetType.FILE:
        return FileDataSet(mode,
                           images,
                           texts,
                           skip_invalid=skip_invalid,
                           remove_invalid=remove_invalid,
                           non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.ABBYY:
        return AbbyyDataSet(mode,
                            images,
                            texts,
                            skip_invalid=skip_invalid,
                            remove_invalid=remove_invalid,
                            non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.PAGEXML:
        return PageXMLDataset(mode,
                              images,
                              texts,
                              skip_invalid=skip_invalid,
                              remove_invalid=remove_invalid,
                              non_existing_as_empty=non_existing_as_empty,
                              args=args)
    else:
        raise Exception("Unsuppoted dataset type {}".format(type))
Example #8
def create_dataset(type: DataSetType,
                   mode: DataSetMode,
                   images = list(),
                   texts = list(),
                   skip_invalid=False,
                   remove_invalid=True,
                   non_existing_as_empty=False,
                   args = dict(),
                   ):
    if DataSetType.files(type):
        if images:
            images.sort()

        if texts:
            texts.sort()

        if images and texts and len(images) > 0 and len(texts) > 0:
            images, texts = keep_files_with_same_file_name(images, texts)

    if type == DataSetType.RAW:
        return RawDataSet(mode, images, texts)

    elif type == DataSetType.FILE:
        return FileDataSet(mode, images, texts,
                           skip_invalid=skip_invalid,
                           remove_invalid=remove_invalid,
                           non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.ABBYY:
        return AbbyyDataSet(mode, images, texts,
                            skip_invalid=skip_invalid,
                            remove_invalid=remove_invalid,
                            non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.PAGEXML:
        return PageXMLDataset(mode, images, texts,
                              skip_invalid=skip_invalid,
                              remove_invalid=remove_invalid,
                              non_existing_as_empty=non_existing_as_empty,
                              args=args)
    elif type == DataSetType.EXTENDED_PREDICTION:
        from .extended_prediction_dataset import ExtendedPredictionDataSet
        return ExtendedPredictionDataSet(texts=texts)
    else:
        raise Exception("Unsuppoted dataset type {}".format(type))
Example #9
def create_train_dataset(cfg: CfgNode, dataset_args=None):
    gt_extension = cfg.DATASET.TRAIN.GT_EXTENSION if cfg.DATASET.TRAIN.GT_EXTENSION is not False else DataSetType.gt_extension(
        cfg.DATASET.TRAIN.TYPE)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(cfg.DATASET.TRAIN.PATH))
    if not cfg.DATASET.TRAIN.TEXT_FILES:
        if gt_extension:
            gt_txt_files = [
                split_all_ext(f)[0] + gt_extension for f in input_image_files
            ]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(cfg.DATASET.TRAIN.TEXT_FILES))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        cfg.DATASET.TRAIN.TYPE,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset
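
The cfg keys consumed above can be supplied through a yacs CfgNode. A minimal, illustrative configuration sketch; the key names are taken from the function, the values are placeholders, and TYPE would normally be a DataSetType rather than a plain string:

from yacs.config import CfgNode as CN

cfg = CN()
cfg.DATASET = CN()
cfg.DATASET.TRAIN = CN()
cfg.DATASET.TRAIN.PATH = ["train/*.png"]   # image globs
cfg.DATASET.TRAIN.TEXT_FILES = []          # empty: derive from GT_EXTENSION
cfg.DATASET.TRAIN.GT_EXTENSION = ".gt.txt"
cfg.DATASET.TRAIN.TYPE = "FILE"            # placeholder for a DataSetType value
cfg.DATALOADER = CN()
cfg.DATALOADER.NO_SKIP_INVALID_GT = False

# dataset = create_train_dataset(cfg)
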
Example #10
    def create_data_reader(
        cls,
        type: Union[DataSetType, str],
        mode: PipelineMode,
        images: List[str] = None,
        texts: List[str] = None,
        skip_invalid=False,
        remove_invalid=True,
        non_existing_as_empty=False,
        args: FileDataReaderArgs = None,
    ) -> DataReader:
        if images is None:
            images = []

        if texts is None:
            texts = []

        if args is None:
            args = dict()

        if type in cls.CUSTOM_READERS:
            return cls.CUSTOM_READERS[type](
                mode=mode,
                images=images,
                texts=texts,
                skip_invalid=skip_invalid,
                remove_invalid=remove_invalid,
                non_existing_as_empty=non_existing_as_empty,
                args=args,
            )

        if DataSetType.files(type):
            if images:
                images.sort()

            if texts:
                texts.sort()

            if images and texts and len(images) > 0 and len(texts) > 0:
                images, texts = keep_files_with_same_file_name(images, texts)

        if type == DataSetType.RAW:
            from calamari_ocr.ocr.dataset.datareader.raw import RawDataReader
            return RawDataReader(mode, images, texts)

        elif type == DataSetType.FILE:
            from calamari_ocr.ocr.dataset.datareader.file import FileDataReader
            return FileDataReader(mode,
                                  images,
                                  texts,
                                  skip_invalid=skip_invalid,
                                  remove_invalid=remove_invalid,
                                  non_existing_as_empty=non_existing_as_empty)
        elif type == DataSetType.ABBYY:
            from calamari_ocr.ocr.dataset.datareader.abbyy import AbbyyReader
            return AbbyyReader(mode,
                               images,
                               texts,
                               skip_invalid=skip_invalid,
                               remove_invalid=remove_invalid,
                               non_existing_as_empty=non_existing_as_empty)
        elif type == DataSetType.PAGEXML:
            from calamari_ocr.ocr.dataset.datareader.pagexml.reader import PageXMLReader
            return PageXMLReader(mode,
                                 images,
                                 texts,
                                 skip_invalid=skip_invalid,
                                 remove_invalid=remove_invalid,
                                 non_existing_as_empty=non_existing_as_empty,
                                 args=args)
        elif type == DataSetType.HDF5:
            from calamari_ocr.ocr.dataset.datareader.hdf5 import Hdf5Reader
            return Hdf5Reader(mode, images, texts)
        elif type == DataSetType.EXTENDED_PREDICTION:
            from calamari_ocr.ocr.dataset.extended_prediction_dataset import ExtendedPredictionDataSet
            return ExtendedPredictionDataSet(texts=texts)
        elif type == DataSetType.GENERATED_LINE:
            from calamari_ocr.ocr.dataset.datareader.generated_line_dataset.dataset import GeneratedLineDataset
            return GeneratedLineDataset(mode, args=args)
        else:
            raise Exception("Unsupported dataset type {}".format(type))
Example #11
def main():
    parser = ArgumentParser()
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation",
                        type=str,
                        nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument(
        "--validation_text_files",
        nargs="+",
        default=None,
        help=
        "Optional list of validation GT files if they are in other directory")
    parser.add_argument(
        "--validation_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--validation_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)

    # input files
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location"
    )
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [
            split_all_ext(f)[0] + args.gt_extension for f in input_image_files
        ]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(args.dataset,
                             DataSetMode.TRAIN,
                             images=input_image_files,
                             texts=gt_txt_files,
                             skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [
                split_all_ext(f)[0] + args.validation_extension
                for f in validation_image_files
            ]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                        os.path.basename(gt))[0]:
                    raise Exception(
                        "Expected identical basenames of validation file: {} and {}"
                        .format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception(
                "Some validation ground truth files occur more than once in the data set."
            )

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(
            len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

        trainer = Trainer(checkpoint_params,
                          dataset,
                          validation_dataset=validation_dataset,
                          weights=args.checkpoint)
        trainer.train(progress_bar=True)
Example #12
def run(args):

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    if args.text_generator_params is not None:
        with open(args.text_generator_params, 'r') as f:
            args.text_generator_params = json_format.Parse(f.read(), TextGeneratorParameters())
    else:
        args.text_generator_params = TextGeneratorParameters()

    if args.line_generator_params is not None:
        with open(args.line_generator_params, 'r') as f:
            args.line_generator_params = json_format.Parse(f.read(), LineGeneratorParameters())
    else:
        args.line_generator_params = LineGeneratorParameters()

    dataset_args = {
        'line_generator_params': args.line_generator_params,
        'text_generator_params': args.text_generator_params,
        'pad': args.dataset_pad,
        'text_index': args.pagexml_text_index,
    }

    # Training dataset
    dataset = create_train_dataset(args, dataset_args)

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt,
            args=dataset_args,
        )
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    if args.data_preprocessing is None or len(args.data_preprocessing) == 0:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    params.model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        pp = params.model.data_preprocessor.children.add()
        pp.type = DataPreprocessorParams.Type.Value(preproc) if isinstance(preproc, str) else preproc
        pp.line_height = args.line_height
        pp.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads
    params.model.network.backend.shuffle_buffer_size = args.shuffle_buffer_size

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      keep_loaded_codec=args.keep_loaded_codec,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )
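
run() above lets every command-line option be supplied from a single JSON file instead. A standalone sketch of that pattern with only the standard library; the argument names and the file path are illustrative:

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--files", nargs="+", default=[])
parser.add_argument("--gt_extension", default=None)
args = parser.parse_args([])  # empty argv, just for the sketch

# A single JSON "file" argument is treated as a bundle of stored options.
if len(args.files) == 1 and args.files[0].endswith(".json"):
    with open(args.files[0], "r") as f:
        for key, value in json.load(f).items():
            setattr(args, key, value)
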
Example #13
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location",
        required=True)
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--processes",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")

    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument("--data_preprocessing",
                        nargs="+",
                        type=DataPreprocessorParams.Type.Value,
                        choices=DataPreprocessorParams.Type.values(),
                        default=[DataPreprocessorParams.DEFAULT_NORMALIZER])

    args = parser.parse_args()

    # Text/Data processing
    if args.data_preprocessing is None or len(args.data_preprocessing) == 0:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    data_preprocessor = DataPreprocessorParams()
    data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        pp = data_preprocessor.children.add()
        pp.type = preproc
        pp.line_height = args.line_height
        pp.pad = args.pad

    # Text pre processing (reading)
    text_preprocessor = TextProcessorParams()
    text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(text_preprocessor.children.add(),
                                   default=args.text_normalization)
    default_text_regularizer_params(text_preprocessor.children.add(),
                                    groups=args.text_regularization)
    strip_processor_params = text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    text_preprocessor = text_processor_from_proto(text_preprocessor)
    data_preprocessor = data_processor_from_proto(data_preprocessor)

    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        if args.gt_extension:
            gt_txt_files = [
                split_all_ext(f)[0] + args.gt_extension
                for f in input_image_files
            ]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        non_existing_as_empty=True,
    )

    if len(args.select) == 0:
        args.select = range(len(dataset.samples()))
        dataset._samples = dataset.samples()
    else:
        dataset._samples = [dataset.samples()[i] for i in args.select]

    samples = dataset.samples()

    print("Found {} files in the dataset".format(len(dataset)))

    with StreamingInputDataset(
            dataset,
            data_preprocessor,
            text_preprocessor,
            SimpleDataAugmenter(),
            args.n_augmentations,
    ) as input_dataset:
        f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
        row, col = 0, 0
        for i, (id, sample) in enumerate(
                zip(args.select, input_dataset.generator(args.processes))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(samples) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0
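
The display loop above fills an n_rows x n_cols grid column by column and has to account for plt.subplots returning a 1-D axes array when n_cols == 1. A self-contained sketch of that layout with synthetic line images (sizes and titles are placeholders):

import matplotlib.pyplot as plt
import numpy as np

n_rows, n_cols = 3, 2
fig, ax = plt.subplots(n_rows, n_cols, sharey="all")

for i in range(n_rows * n_cols):
    row, col = i % n_rows, i // n_rows  # fill column by column, like the loop above
    axis = ax[row] if n_cols == 1 else ax[row, col]
    axis.imshow(np.random.rand(48, 200), cmap="gray")  # stand-in for a line image
    axis.set_title(f"ID: {i}")

plt.show()
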
Example #14
def run(args):

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )
Example #15
def main():
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth fils with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

        trainer = Trainer(checkpoint_params, dataset,
                          validation_dataset=validation_dataset,
                          weights=args.checkpoint)
        trainer.train(progress_bar=True)