Example #1
 def __init__(self):
     self.dataset = DataSetType.FILE
     self.gt_extension = DataSetType.gt_extension(self.dataset)
     self.files = glob_all(
         [os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")])
     self.seed = 24
     self.backend = "tensorflow"
     self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5"
     self.line_height = 48
     self.pad = 16
     self.num_threads = 1
     self.display = 1
     self.batch_size = 1
     self.checkpoint_frequency = 1000
     self.epochs = 1
     self.samples_per_epoch = 8
     self.stats_size = 100
     self.no_skip_invalid_gt = False
     self.no_progress_bars = True
     self.output_dir = None
     self.output_model_prefix = "uw3_50lines"
     self.bidi_dir = None
     self.weights = None
     self.ema_weights = False
     self.whitelist_files = []
     self.whitelist = []
     self.gradient_clipping_norm = 5
     self.validation = None
     self.validation_dataset = DataSetType.FILE
     self.validation_extension = None
     self.validation_split_ratio = None
     self.early_stopping_frequency = -1
     self.early_stopping_nbest = 10
     self.early_stopping_at_accuracy = 0.99
     self.early_stopping_best_model_prefix = "uw3_50lines_best"
     self.early_stopping_best_model_output_dir = self.output_dir
     self.n_augmentations = 0
     self.num_inter_threads = 0
     self.num_intra_threads = 0
     self.text_regularization = ["extended"]
     self.text_normalization = "NFC"
     self.text_generator_params = None
     self.line_generator_params = None
     self.pagexml_text_index = 0
     self.text_files = None
     self.only_train_on_augmented = False
     self.data_preprocessing = [p.name for p in default_image_processors()]
     self.shuffle_buffer_size = 1000
     self.keep_loaded_codec = False
     self.train_data_on_the_fly = False
     self.validation_data_on_the_fly = False
     self.no_auto_compute_codec = False
     self.dataset_pad = 0
     self.debug = False
     self.train_verbose = True
     self.use_train_as_val = False
     self.ensemble = -1
     self.masking_mode = 1
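
Example #1 above is a training-configuration fixture: it fills an args-style object with the defaults used to train on the bundled uw3_50lines data (files, seed, network string, line height, padding, early-stopping settings, and so on). As a minimal sketch of the same idea using only the standard library, assuming you just need a plain attribute container rather than this particular class, one could write:

import os
from types import SimpleNamespace

# Hypothetical stand-in for the fixture above; only a few of its fields are reproduced.
this_dir = os.path.dirname(os.path.realpath(__file__))
args = SimpleNamespace(
    files=[os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")],
    seed=24,
    network="cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5",
    line_height=48,
    pad=16,
    batch_size=1,
    epochs=1,
)
print(args.network)

Any code that reads these settings as attributes works the same way with either object.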
Example #2
 def default_params(cls) -> DataParams:
     params: DataParams = super(Data, cls).default_params()
     params.pre_proc = SequentialProcessorPipelineParams(
         run_parallel=True,
         processors=default_image_processors() +
         default_text_pre_processors() + [
             AugmentationProcessorParams(modes={PipelineMode.TRAINING}),
             PrepareSampleProcessorParams(modes=INPUT_PROCESSOR),
         ],
     )
     params.post_proc = SequentialProcessorPipelineParams(
         run_parallel=True,
         processors=[
             ReshapeOutputsProcessorParams(),
             CTCDecoderProcessorParams(),
         ] + default_text_pre_processors())
     return params
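
This default_params override wires Calamari's default image and text pre-processors into a sequential pipeline, appends augmentation (training mode only) and sample preparation, and sets up output reshaping plus CTC decoding for post-processing. A minimal usage sketch, assuming Data exposes this as a classmethod as the `cls` signature suggests:

# Sketch only: obtain the defaults and tweak them before constructing the data pipeline.
params = Data.default_params()
params.pre_proc.run_parallel = False  # e.g. run preprocessing serially while debugging
print(len(params.pre_proc.processors), "pre-processors configured")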
Example #3
 def get_default_params(cls) -> DataParams:
     params: DataParams = super(Data, cls).get_default_params()
     params.pre_processors_ = SamplePipelineParams(
         run_parallel=True,
         sample_processors=default_image_processors() +
         default_text_pre_processors() + [
             DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                        {PipelineMode.Training}),
             DataProcessorFactoryParams(PrepareSampleProcessor.__name__,
                                        INPUT_PROCESSOR),
         ],
     )
     params.post_processors_ = SamplePipelineParams(
         run_parallel=True,
         sample_processors=[
             DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
             DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
         ] + default_text_pre_processors())
     return params
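
Example #3 builds the same defaults against the older API: pipelines are SamplePipelineParams, processors are registered by class name through DataProcessorFactoryParams, and the fields carry trailing underscores (pre_processors_, post_processors_). A small inspection sketch, assuming get_default_params is likewise a classmethod:

# Sketch: list the factory entries of the older-style pipeline.
params = Data.get_default_params()
for factory_params in params.pre_processors_.sample_processors:
    print(factory_params)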
Example #4
    this_dir = os.path.dirname(os.path.realpath(__file__))
    base_path = os.path.abspath(os.path.join(this_dir, "..", "..", "test", "data", "uw3_50lines", "train"))

    fdr = FileDataParams(
        num_processes=8,
        images=[os.path.join(base_path, "*.png")],
        limit=1000,
    )

    params = DataParams(
        codec=Codec("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,:;-?+=_()*{}[]`@#$%^&'\""),
        downscale_factor=4,
        line_height=48,
        pre_proc=SequentialProcessorPipelineParams(
            run_parallel=True,
            processors=default_image_processors()
            + default_text_pre_processors()
            + [
                AugmentationProcessorParams(
                    modes={PipelineMode.TRAINING},
                    data_aug_params=DataAugmentationAmount(amount=2),
                ),
                PrepareSampleProcessorParams(modes=INPUT_PROCESSOR),
            ],
        ),
        post_proc=SequentialProcessorPipelineParams(run_parallel=False),
        train=fdr,
        val=fdr,
        input_channels=1,
    )
    params = DataParams.from_json(params.to_json())
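
Example #4 assembles a complete DataParams by hand: an explicit Codec, a glob-based FileDataParams reader used for both train and val, image/text pre-processing with two augmentations per line, and finally a serialization round trip. A short check of that round trip, using only the calls already shown in the snippet:

# Round-trip sketch: the restored object should preserve the configured values.
restored = DataParams.from_json(params.to_json())
assert restored.line_height == params.line_height
assert restored.input_channels == 1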
Example #5
def setup_train_args(parser, omit=None):
    # required params for args
    from calamari_ocr.ocr.dataset.data import Data

    if omit is None:
        omit = []

    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)

    if "files" not in omit:
        parser.add_argument(
            "--files",
            nargs="+",
            default=[],
            help=
            "List all image files that shall be processed. Ground truth fils with the same "
            "base name but with '.gt.txt' as extension are required at the same location"
        )
        parser.add_argument(
            "--text_files",
            nargs="+",
            default=None,
            help="Optional list of GT files if they are in other directory")
        parser.add_argument(
            "--gt_extension",
            default=None,
            help=
            "Default extension of the gt files (expected to exist in same dir)"
        )
        parser.add_argument("--dataset",
                            type=DataSetType.from_string,
                            choices=list(DataSetType),
                            default=DataSetType.FILE)

    parser.add_argument(
        "--train_data_on_the_fly",
        action='store_true',
        default=False,
        help=
        'Instead of preloading all data before training, load the data on the fly. '
        'This is slower, but might be required for limited RAM or large datasets'
    )

    parser.add_argument(
        "--seed",
        type=int,
        default="0",
        help=
        "Seed for random operations. If negative or zero a 'random' seed is used"
    )
    parser.add_argument(
        "--network",
        type=str,
        default="cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5",
        help="The network structure")
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument(
        "--display",
        type=int,
        default=1,
        help=
        "Frequency of how often an output shall occur during training [epochs]"
    )
    parser.add_argument("--batch_size",
                        type=int,
                        default=1,
                        help="The batch size to use for training")
    parser.add_argument("--ema_weights",
                        action="store_true",
                        default=False,
                        help="Use exponentially averaged weights")
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=-1,
        help=
        "How often to write checkpoints during training [epochs]. "
        "If -1 (default), the early_stopping_frequency is used. If 0, no checkpoints are written"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="The number of iterations for training. "
        "If using early stopping, this is the maximum number of iterations")
    parser.add_argument(
        "--samples_per_epoch",
        type=float,
        default=-1,
        help=
        "The number of samples to process per epoch. By default the size of the training dataset."
        "If in (0,1) it is relative to the dataset size")
    parser.add_argument(
        "--early_stopping_at_accuracy",
        type=float,
        default=1.0,
        help="Stop training if the early stopping accuracy reaches this value")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")

    if "output_dir" not in omit:
        parser.add_argument(
            "--output_dir",
            type=str,
            default="",
            help="Default directory where to store checkpoints and models")
    if "output_model_prefix" not in omit:
        parser.add_argument("--output_model_prefix",
                            type=str,
                            default="model_",
                            help="Prefix for storing checkpoints and models")

    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help=
        "The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")

    if "weights" not in omit:
        parser.add_argument("--weights",
                            type=str,
                            default=None,
                            help="Load network weights from the given file.")

    parser.add_argument(
        "--no_auto_compute_codec",
        action='store_true',
        default=False,
        help="Do not compute the codec automatically. See also whitelist")
    parser.add_argument(
        "--whitelist_files",
        type=str,
        nargs="+",
        default=[],
        help=
        "Whitelist of txt files that may not be removed on restoring a model")
    parser.add_argument(
        "--whitelist",
        type=str,
        nargs="+",
        default=[],
        help=
        "Whitelist of characters that may not be removed on restoring a model. "
        "For large dataset you can use this to skip the automatic codec computation "
        "(see --no_auto_compute_codec)")
    parser.add_argument(
        "--keep_loaded_codec",
        action='store_true',
        default=False,
        help="Fully include the codec of the loaded model to the new codec")
    parser.add_argument("--ensemble",
                        type=int,
                        default=-1,
                        help="Number of voting models")
    parser.add_argument("--masking_mode",
                        type=int,
                        default=0,
                        help="Do not use")  # TODO: remove

    # clipping
    parser.add_argument("--gradient_clipping_norm",
                        type=float,
                        default=5,
                        help="Clipping constant of the norm of the gradients.")

    # early stopping
    if "validation" not in omit:
        parser.add_argument(
            "--validation",
            type=str,
            nargs="+",
            help="Validation line files used for early stopping")
        parser.add_argument(
            "--validation_text_files",
            nargs="+",
            default=None,
            help=
            "Optional list of validation GT files if they are in other directory"
        )
        parser.add_argument(
            "--validation_extension",
            default=None,
            help=
            "Default extension of the gt files (expected to exist in same dir). "
            "By default same as gt_extension")
        parser.add_argument(
            "--validation_dataset",
            type=DataSetType.from_string,
            choices=list(DataSetType),
            default=None,
            help=
            "Default validation data set type. By default same as --dataset")
        parser.add_argument("--use_train_as_val",
                            action='store_true',
                            default=False,
                            help="Use the training data for validation as well")
        parser.add_argument(
            "--validation_split_ratio",
            type=float,
            default=None,
            help=
            "Use n percent of the training dataset for validation. Can not be used with --validation"
        )

    parser.add_argument(
        "--validation_data_on_the_fly",
        action='store_true',
        default=False,
        help=
        'Instead of preloading all data before training, load the data on the fly. '
        'This is slower, but might be required for limited RAM or large datasets'
    )

    parser.add_argument("--early_stopping_frequency",
                        type=int,
                        default=1,
                        help="The frequency of early stopping [epochs].")
    parser.add_argument(
        "--early_stopping_nbest",
        type=int,
        default=5,
        help=
        "The number of models that must be worse than the current best model to stop"
    )
    if "early_stopping_best_model_prefix" not in omit:
        parser.add_argument(
            "--early_stopping_best_model_prefix",
            type=str,
            default="best",
            help="The prefix of the best model using early stopping")
    if "early_stopping_best_model_output_dir" not in omit:
        parser.add_argument(
            "--early_stopping_best_model_output_dir",
            type=str,
            default=None,
            help="Path where to store the best model. Default is output_dir")
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument(
        "--only_train_on_augmented",
        action="store_true",
        default=False,
        help=
        "When training with augmentations usually the model is retrained in a second run with "
        "only the non augmented data. This will take longer. Use this flag to disable this "
        "behavior.")

    # text normalization/regularization
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])

    # text/line generation params (loaded from json files)
    parser.add_argument("--text_generator_params", type=str, default=None)
    parser.add_argument("--line_generator_params", type=str, default=None)

    # additional dataset args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    parser.add_argument("--pagexml_text_index", default=0)

    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--train_verbose", default=1)
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location",
        required=True)
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--processes",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")

    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help=
        "Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])
    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help=
        "The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")

    parser.add_argument("--preload",
                        action='store_true',
                        help='Simulate preloading')
    parser.add_argument("--as_validation",
                        action='store_true',
                        help="Access as validation instead of training data.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    dataset_args = FileDataReaderArgs(pad=args.pad)

    data_params: DataParams = Data.get_default_params()
    data_params.train = PipelineParams(
        type=args.dataset,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=1,
        num_processes=args.processes,
    )
    data_params.val = data_params.train

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(run_parallel=True)
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        # DataProcessorFactoryParams(PrepareSampleProcessor.__name__),  # NOT THIS, since, we want to access raw input
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    data = Data(data_params)
    data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
    if not args.preload:
        reader: DataReader = data_pipeline.reader()
        if len(args.select) == 0:
            args.select = range(len(reader))
        else:
            reader._samples = [reader.samples()[i] for i in args.select]
    else:
        data.preload()
        data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
        samples = data_pipeline.samples
        if len(args.select) == 0:
            args.select = range(len(samples))
        else:
            data_pipeline.samples = [samples[i] for i in args.select]

    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(
                zip(args.select,
                    dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0
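
main() is a standalone inspection script: it reads the line images (optionally preloaded), applies the configured text normalization, regularization, and augmentation processors, and plots the selected samples in an n_rows x n_cols matplotlib grid with their ground truth as titles. The snippet omits the usual entry-point guard; a sketch with a placeholder script name in the comment:

# Placeholder invocation (script name and data paths are illustrative):
#   python plot_dataset_lines.py --files data/uw3_50lines/train/*.png --n_rows 3 --n_cols 2
if __name__ == "__main__":
    main()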