Exemplo n.º 1
0
    def __init__(
        self,
        mode: PipelineMode,
        files: List[str] = None,
        xmlfiles: List[str] = None,
        skip_invalid=False,
        remove_invalid=True,
        binary=False,
        non_existing_as_empty=False,
    ):
        """ Create a dataset from a Path as String

        Parameters
        ----------
        files : [], required
            image files
        xmlfiles : [], optional
            Abbyy XML ground-truth files; derived from ``files`` when omitted
        skip_invalid : bool, optional
            skip invalid files
        remove_invalid : bool, optional
            remove invalid files
        binary : bool, optional
            stored on the instance as ``self.binary``; presumably whether
            images are loaded as binary — TODO confirm against reader usage
        non_existing_as_empty : bool, optional
            stored as ``self._non_existing_as_empty``; presumably treats
            missing files as empty instead of erroring — verify in base class
        """

        super().__init__(mode, skip_invalid, remove_invalid)

        # Normalize None/empty arguments to empty lists.
        self.xmlfiles = xmlfiles if xmlfiles else []
        self.files = files if files else []

        self._non_existing_as_empty = non_existing_as_empty
        if len(self.xmlfiles) == 0:
            from calamari_ocr.ocr.dataset import DataSetType
            # Derive XML paths from the image paths by swapping the extension.
            # BUGFIX: iterate self.files (always a list) instead of the raw
            # `files` argument, which may be None and raised a TypeError here.
            self.xmlfiles = [
                split_all_ext(p)[0] +
                DataSetType.gt_extension(DataSetType.ABBYY)
                for p in self.files
            ]

        if len(self.files) == 0:
            self.files = [None] * len(self.xmlfiles)

        self.book = XMLReader(self.files, self.xmlfiles, skip_invalid,
                              remove_invalid).read()
        self.binary = binary

        # Register one sample per text format of each line on each page.
        for p, page in enumerate(self.book.pages):
            for l, line in enumerate(page.getLines()):
                for f, fo in enumerate(line.formats):
                    self.add_sample({
                        "image_path": page.imgFile,
                        "xml_path": page.xmlFile,
                        # Unique id: <basename>_<page idx>_<line idx>_<format idx>
                        "id": "{}_{}_{}_{}".format(
                            os.path.splitext(
                                page.xmlFile if page.xmlFile else page.imgFile
                            )[0], p, l, f),
                        "line": line,
                        "format": fo,
                    })
Exemplo n.º 2
0
    def finish_chunck(self):
        """Flush the buffered (text, image) samples into a new HDF5 chunk file.

        No-op when the text buffer is empty. Otherwise computes the codec over
        the buffered transcripts, writes transcripts (as codec indices), image
        dimensions, flattened image bytes and the codec to
        ``<output_filename>_<chunk:03d><hdf5-ext>``, records the filename in
        ``self.files``, then resets the buffers and advances the chunk counter.
        """
        if len(self.text) == 0:
            return

        codec = self.compute_codec()

        filename = "{}_{:03d}{}".format(self.output_filename, self.current_chunk, DataSetType.gt_extension(DataSetType.HDF5))
        self.files.append(filename)
        # BUGFIX (resource safety): use a context manager so the HDF5 handle is
        # closed even if any of the writes below raises; the original leaked
        # the open file on error.
        with h5py.File(filename, 'w') as file:
            # Variable-length dtypes: each row may have a different length.
            dti32 = h5py.special_dtype(vlen=np.dtype('int32'))
            dtui8 = h5py.special_dtype(vlen=np.dtype('uint8'))
            file.create_dataset('transcripts', (len(self.text),), dtype=dti32, compression='gzip')
            file.create_dataset('images_dims', data=[d.shape for d in self.data], dtype=int)
            file.create_dataset('images', (len(self.text),), dtype=dtui8, compression='gzip')
            file.create_dataset('codec', data=list(map(ord, codec)))
            # Transcripts are stored as indices into the codec; images as flat bytes.
            file['transcripts'][...] = [list(map(codec.index, d)) for d in self.text]
            file['images'][...] = [d.reshape(-1) for d in self.data]

        self.current_chunk += 1
        self.data = []
        self.text = []
Exemplo n.º 3
0
def run(args):
    """Configure and start a Calamari training run from parsed CLI args.

    Builds the data pipeline parameters (train/validation readers, text
    pre/post processors, augmentation), fills in the trainer parameters,
    instantiates the scenario and runs training. `args` is an argparse-style
    namespace produced by the accompanying CLI parser.
    """
    # local imports (to prevent tensorflow from being imported to early)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data

    # check if loading a json file: a single .json argument replaces the CLI
    # args with the values stored in the file.
    if len(args.files) == 1 and args.files[0].endswith("json"):
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    check_train_args(args)

    if args.output_dir is not None:
        args.output_dir = os.path.abspath(args.output_dir)
        setup_log(args.output_dir, append=False)

    # parse whitelist: a single string is treated as a list of characters
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    # Synthetic line generation needs text/line generator params (from JSON
    # files when given, defaults otherwise).
    if args.dataset == DataSetType.GENERATED_LINE or args.validation_dataset == DataSetType.GENERATED_LINE:
        if args.text_generator_params is not None:
            with open(args.text_generator_params, 'r') as f:
                args.text_generator_params = TextGeneratorParams.from_json(
                    f.read())
        else:
            args.text_generator_params = TextGeneratorParams()

        if args.line_generator_params is not None:
            with open(args.line_generator_params, 'r') as f:
                args.line_generator_params = LineGeneratorParams.from_json(
                    f.read())
        else:
            args.line_generator_params = LineGeneratorParams()

    dataset_args = FileDataReaderArgs(
        line_generator_params=args.line_generator_params,
        text_generator_params=args.text_generator_params,
        pad=args.dataset_pad,
        text_index=args.pagexml_text_index,
    )

    params: TrainerParams = Scenario.default_trainer_params()

    # =================================================================================================================
    # Data Params
    data_params: DataParams = params.scenario_params.data_params
    data_params.train = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=args.batch_size,
        num_processes=args.num_threads,
    )
    if args.validation_split_ratio:
        if args.validation is not None:
            raise ValueError("Set either validation_split_ratio or validation")
        if not 0 < args.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # resolve all files so we can split them
        data_params.train = data_params.train.prepare_for_mode(
            PipelineMode.Training)
        n = int(args.validation_split_ratio * len(data_params.train.files))
        logger.info(
            f"Splitting training and validation files with ratio {args.validation_split_ratio}: "
            f"{n}/{len(data_params.train.files) - n} for validation/training.")
        indices = list(range(len(data_params.train.files)))
        shuffle(indices)
        all_files = data_params.train.files
        all_text_files = data_params.train.text_files

        # split train and val img/gt files. Use train settings
        data_params.train.files = [all_files[i] for i in indices[:n]]
        if all_text_files is not None:
            assert (len(all_text_files) == len(all_files))
            data_params.train.text_files = [
                all_text_files[i] for i in indices[:n]
            ]
        data_params.val = PipelineParams(
            type=args.dataset,
            skip_invalid=not args.no_skip_invalid_gt,
            remove_invalid=True,
            files=[all_files[i] for i in indices[n:]],
            text_files=[all_text_files[i] for i in indices[n:]]
            if data_params.train.text_files is not None else None,
            gt_extension=args.gt_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    elif args.validation:
        data_params.val = PipelineParams(
            type=args.validation_dataset,
            files=args.validation,
            text_files=args.validation_text_files,
            skip_invalid=not args.no_skip_invalid_gt,
            gt_extension=args.validation_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    else:
        data_params.val = None

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    # BUGFIX: the original assigned this SamplePipelineParams object to the
    # `run_parallel` ATTRIBUTE of the existing post_processors_ instead of
    # replacing post_processors_ itself, silently discarding the Reshape/CTC
    # decoder processors (compare the correct pre_processors_ line above).
    data_params.post_processors_ = SamplePipelineParams(
        run_parallel=False,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ])

    # Image pre-processing chain selected on the CLI; propagate the pad value
    # to processors that accept it.
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    # Augmentation applies to training only; PrepareSample runs last.
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        DataProcessorFactoryParams(PrepareSampleProcessor.__name__,
                                   INPUT_PROCESSOR),
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    # =================================================================================================================
    # Trainer Params
    params.calc_ema = args.ema_weights
    params.verbose = args.train_verbose
    params.force_eager = args.debug
    params.skip_model_load_test = not args.debug
    params.scenario_params.debug_graph_construction = args.debug
    params.epochs = args.epochs
    # samples_per_epoch >= 1 is an absolute count; values < 1 are treated as a
    # fraction of the dataset via scale_epoch_size.
    params.samples_per_epoch = int(
        args.samples_per_epoch) if args.samples_per_epoch >= 1 else -1
    params.scale_epoch_size = abs(
        args.samples_per_epoch) if args.samples_per_epoch < 1 else 1
    params.skip_load_model_test = True
    params.scenario_params.export_frozen = False
    params.checkpoint_save_freq_ = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.checkpoint_dir = args.output_dir
    params.test_every_n = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.data_aug_retrain_on_original = not args.only_train_on_augmented
    params.use_training_as_validation = args.use_train_as_val
    if args.seed > 0:
        params.random_seed = args.seed

    params.optimizer_params.clip_grad = args.gradient_clipping_norm
    params.codec_whitelist = whitelist
    params.keep_loaded_codec = args.keep_loaded_codec
    params.preload_training = not args.train_data_on_the_fly
    params.preload_validation = not args.validation_data_on_the_fly
    params.warmstart_params.model = args.weights

    params.auto_compute_codec = not args.no_auto_compute_codec
    params.progress_bar = not args.no_progress_bars

    params.early_stopping_params.frequency = args.early_stopping_frequency
    params.early_stopping_params.upper_threshold = 0.9
    params.early_stopping_params.lower_threshold = 1.0 - args.early_stopping_at_accuracy
    params.early_stopping_params.n_to_go = args.early_stopping_nbest
    params.early_stopping_params.best_model_name = ''
    params.early_stopping_params.best_model_output_dir = args.early_stopping_best_model_output_dir
    params.scenario_params.default_serve_dir_ = f'{args.early_stopping_best_model_prefix}.ckpt.h5'
    params.scenario_params.trainer_params_filename_ = f'{args.early_stopping_best_model_prefix}.ckpt.json'

    # =================================================================================================================
    # Model params
    params_from_definition_string(args.network, params)
    params.scenario_params.model_params.ensemble = args.ensemble
    params.scenario_params.model_params.masking_mode = args.masking_mode

    scenario = Scenario(params.scenario_params)
    trainer = scenario.create_trainer(params)
    trainer.train()
Exemplo n.º 4
0
def main():
    """Debug/visualization entry point: build a data pipeline from CLI args
    and display batches of (augmented) line images with their texts in a
    matplotlib grid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location",
        required=True)
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--processes",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")

    # Plot grid layout and optional explicit sample selection.
    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])
    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help=
        "The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")

    parser.add_argument("--preload",
                        action='store_true',
                        help='Simulate preloading')
    parser.add_argument("--as_validation",
                        action='store_true',
                        help="Access as validation instead of training data.")

    args = parser.parse_args()

    # Derive the GT extension from the dataset type when not given explicitly.
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    dataset_args = FileDataReaderArgs(pad=args.pad, )

    # Build the pipeline params; validation reuses the training pipeline so
    # --as_validation only switches the access mode, not the data.
    data_params: DataParams = Data.get_default_params()
    data_params.train = PipelineParams(
        type=args.dataset,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=1,
        num_processes=args.processes,
    )
    data_params.val = data_params.train

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(run_parallel=True)
    # Image pre-processing chain selected on the CLI; propagate the pad value
    # to processors that accept it.
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        # DataProcessorFactoryParams(PrepareSampleProcessor.__name__),  # NOT THIS, since, we want to access raw input
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    data = Data(data_params)
    data_pipeline = data.get_val_data(
    ) if args.as_validation else data.get_train_data()
    # Restrict the pipeline to --select indices (or keep everything). The two
    # branches mirror each other: without --preload we filter the reader's
    # samples in place; with it we preload first and filter the pipeline's
    # sample list instead.
    if not args.preload:
        reader: DataReader = data_pipeline.reader()
        if len(args.select) == 0:
            args.select = range(len(reader))
        else:
            reader._samples = [reader.samples()[i] for i in args.select]
    else:
        data.preload()
        data_pipeline = data.get_val_data(
        ) if args.as_validation else data.get_train_data()
        samples = data_pipeline.samples
        if len(args.select) == 0:
            args.select = range(len(samples))
        else:
            data_pipeline.samples = [samples[i] for i in args.select]

    # Plot samples into an n_rows x n_cols grid, filling column by column;
    # open a fresh figure whenever the grid is full or the data is exhausted.
    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(
                zip(args.select,
                    dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample
            # With a single column, plt.subplots returns a 1-D axes array.
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0