Example #1
def default_trainer_params():
    p = CalamariTestScenario.default_trainer_params()

    with open(
            os.path.join(this_dir, "data", "line_generation_config",
                         "text_gen_params.json")) as f:
        text_gen_params = TextGeneratorParams.from_dict(json.load(f))

    with open(
            os.path.join(this_dir, "data", "line_generation_config",
                         "line_gen_params.json")) as f:
        line_gen_params = LineGeneratorParams.from_dict(json.load(f))

    p.codec.include_files = os.path.join(this_dir, "data",
                                         "line_generation_config",
                                         "whilelist.txt")
    p.codec.auto_compute = False

    p.gen.train = GeneratedLineDatasetParams(
        lines_per_epoch=10,
        preload=False,
        text_generator=text_gen_params,
        line_generator=line_gen_params,
    )
    p.gen.val = GeneratedLineDatasetParams(
        lines_per_epoch=10,
        preload=False,
        text_generator=text_gen_params,
        line_generator=line_gen_params,
    )

    p.gen.setup.val.batch_size = 1
    p.gen.setup.val.num_processes = 1
    p.gen.setup.train.batch_size = 1
    p.gen.setup.train.num_processes = 1
    p.epochs = 1
    p.scenario.data.pre_proc.run_parallel = False
    p.gen.__post_init__()
    p.scenario.data.__post_init__()
    p.scenario.__post_init__()
    p.__post_init__()
    return p
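A minimal sketch of how this fixture might be driven in a smoke test, assuming the scenario instantiation and create_trainer call work the same way as in Example #3 below; none of the names added here are taken from the snippet above.

def smoke_train():
    # Build the one-epoch generated-line configuration from the fixture above.
    p = default_trainer_params()
    # Assumed by analogy with Example #3 (Scenario(...).create_trainer(...)),
    # not shown in this snippet; adjust to the actual test scenario API.
    scenario = CalamariTestScenario(p.scenario)
    trainer = scenario.create_trainer(p)
    trainer.train()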
Example #2
                out.append(
                    Word(self.params.word_separator, Script.NORMAL, 0,
                         FontVariantType.NORMAL))

            charset = [self.charset, self.super_charset,
                       self.sub_charset][script]
            s = "".join(np.random.choice(charset, word_length))
            s = s.strip()

            out.append(Word(s, script, letter_spacing, variant))

        return out


if __name__ == "__main__":
    params = TextGeneratorParams()
    params.word_length_mean = 11
    params.word_length_sigma = 3
    params.number_of_words_mean = 7
    params.number_of_words_sigma = 4
    params.word_separator = " "
    params.sub_script_p = 0.0
    params.super_script_p = 0.2
    params.letter_spacing_p = 0.5
    params.letter_spacing_mean = 1
    params.letter_spacing_sigma = 0.1
    params.charset.extend(
        list(
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789{}[]()_-.;:'\" "
        ))
    params.super_charset.extend(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
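Example #1 loads an equivalent configuration from text_gen_params.json via TextGeneratorParams.from_dict. Below is a minimal sketch of writing the params built above to such a file; to_dict() is assumed to be the serialization counterpart of from_dict and is not shown in these examples.

import json

# Hypothetical round-trip: dump the params configured in the __main__ block
# above into the JSON file that Example #1's fixture reads.
# to_dict() is assumed to mirror TextGeneratorParams.from_dict().
with open("text_gen_params.json", "w") as f:
    json.dump(params.to_dict(), f, indent=2)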
Example #3
def run(args):
    # local imports (to prevent tensorflow from being imported too early)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    check_train_args(args)

    if args.output_dir is not None:
        args.output_dir = os.path.abspath(args.output_dir)
        setup_log(args.output_dir, append=False)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    if args.dataset == DataSetType.GENERATED_LINE or args.validation_dataset == DataSetType.GENERATED_LINE:
        if args.text_generator_params is not None:
            with open(args.text_generator_params, 'r') as f:
                args.text_generator_params = TextGeneratorParams.from_json(
                    f.read())
        else:
            args.text_generator_params = TextGeneratorParams()

        if args.line_generator_params is not None:
            with open(args.line_generator_params, 'r') as f:
                args.line_generator_params = LineGeneratorParams.from_json(
                    f.read())
        else:
            args.line_generator_params = LineGeneratorParams()

    dataset_args = FileDataReaderArgs(
        line_generator_params=args.line_generator_params,
        text_generator_params=args.text_generator_params,
        pad=args.dataset_pad,
        text_index=args.pagexml_text_index,
    )

    params: TrainerParams = Scenario.default_trainer_params()

    # =================================================================================================================
    # Data Params
    data_params: DataParams = params.scenario_params.data_params
    data_params.train = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=args.batch_size,
        num_processes=args.num_threads,
    )
    if args.validation_split_ratio:
        if args.validation is not None:
            raise ValueError("Set either validation_split_ratio or validation")
        if not 0 < args.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # resolve all files so we can split them
        data_params.train = data_params.train.prepare_for_mode(
            PipelineMode.Training)
        n = int(args.validation_split_ratio * len(data_params.train.files))
        logger.info(
            f"Splitting training and validation files with ratio {args.validation_split_ratio}: "
            f"{n}/{len(data_params.train.files) - n} for validation/training.")
        indices = list(range(len(data_params.train.files)))
        shuffle(indices)
        all_files = data_params.train.files
        all_text_files = data_params.train.text_files

        # split train and val img/gt files. Use train settings
        data_params.train.files = [all_files[i] for i in indices[:n]]
        if all_text_files is not None:
            assert (len(all_text_files) == len(all_files))
            data_params.train.text_files = [
                all_text_files[i] for i in indices[:n]
            ]
        data_params.val = PipelineParams(
            type=args.dataset,
            skip_invalid=not args.no_skip_invalid_gt,
            remove_invalid=True,
            files=[all_files[i] for i in indices[n:]],
            text_files=[all_text_files[i] for i in indices[n:]]
            if data_params.train.text_files is not None else None,
            gt_extension=args.gt_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    elif args.validation:
        data_params.val = PipelineParams(
            type=args.validation_dataset,
            files=args.validation,
            text_files=args.validation_text_files,
            skip_invalid=not args.no_skip_invalid_gt,
            gt_extension=args.validation_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    else:
        data_params.val = None

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(
        run_parallel=False,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ])

    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])
    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        DataProcessorFactoryParams(PrepareSampleProcessor.__name__,
                                   INPUT_PROCESSOR),
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    # =================================================================================================================
    # Trainer Params
    params.calc_ema = args.ema_weights
    params.verbose = args.train_verbose
    params.force_eager = args.debug
    params.skip_model_load_test = not args.debug
    params.scenario_params.debug_graph_construction = args.debug
    params.epochs = args.epochs
    params.samples_per_epoch = int(
        args.samples_per_epoch) if args.samples_per_epoch >= 1 else -1
    params.scale_epoch_size = abs(
        args.samples_per_epoch) if args.samples_per_epoch < 1 else 1
    params.skip_load_model_test = True
    params.scenario_params.export_frozen = False
    params.checkpoint_save_freq_ = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.checkpoint_dir = args.output_dir
    params.test_every_n = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.data_aug_retrain_on_original = not args.only_train_on_augmented
    params.use_training_as_validation = args.use_train_as_val
    if args.seed > 0:
        params.random_seed = args.seed

    params.optimizer_params.clip_grad = args.gradient_clipping_norm
    params.codec_whitelist = whitelist
    params.keep_loaded_codec = args.keep_loaded_codec
    params.preload_training = not args.train_data_on_the_fly
    params.preload_validation = not args.validation_data_on_the_fly
    params.warmstart_params.model = args.weights

    params.auto_compute_codec = not args.no_auto_compute_codec
    params.progress_bar = not args.no_progress_bars

    params.early_stopping_params.frequency = args.early_stopping_frequency
    params.early_stopping_params.upper_threshold = 0.9
    params.early_stopping_params.lower_threshold = 1.0 - args.early_stopping_at_accuracy
    params.early_stopping_params.n_to_go = args.early_stopping_nbest
    params.early_stopping_params.best_model_name = ''
    params.early_stopping_params.best_model_output_dir = args.early_stopping_best_model_output_dir
    params.scenario_params.default_serve_dir_ = f'{args.early_stopping_best_model_prefix}.ckpt.h5'
    params.scenario_params.trainer_params_filename_ = f'{args.early_stopping_best_model_prefix}.ckpt.json'

    # =================================================================================================================
    # Model params
    params_from_definition_string(args.network, params)
    params.scenario_params.model_params.ensemble = args.ensemble
    params.scenario_params.model_params.masking_mode = args.masking_mode

    scenario = Scenario(params.scenario_params)
    trainer = scenario.create_trainer(params)
    trainer.train()
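A minimal driver sketch for run(); setup_train_args is an assumption (some helper has to register every command-line flag the function reads) and is not part of the code above.

import argparse

# Hypothetical entry point: collect the CLI flags that run() reads into an
# argparse Namespace and hand them over. setup_train_args is assumed to exist
# alongside run(); otherwise each argument (--files, --output_dir, --whitelist,
# --dataset, --network, ...) has to be registered on the parser by hand.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    setup_train_args(parser)
    run(parser.parse_args())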