def finish_chunck(self):
        if len(self.text) == 0:
            return

        codec = self.compute_codec()

        filename = "{}_{:03d}{}".format(
            self.output_filename, self.current_chunk,
            DataSetType.gt_extension(DataSetType.HDF5))
        self.files.append(filename)
        file = h5py.File(filename, 'w')
        dti32 = h5py.special_dtype(vlen=np.dtype('int32'))
        dtui8 = h5py.special_dtype(vlen=np.dtype('uint8'))
        file.create_dataset('transcripts', (len(self.text), ),
                            dtype=dti32,
                            compression='gzip')
        file.create_dataset('images_dims',
                            data=[d.shape for d in self.data],
                            dtype=int)
        file.create_dataset('images', (len(self.text), ),
                            dtype=dtui8,
                            compression='gzip')
        file.create_dataset('codec', data=list(map(ord, codec)))
        file['transcripts'][...] = [
            list(map(codec.index, d)) for d in self.text
        ]
        file['images'][...] = [d.reshape(-1) for d in self.data]
        file.close()

        self.current_chunk += 1
        self.data = []
        self.text = []
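
For orientation, a minimal read-back sketch (not part of the original code): it opens one chunk file written by the method above with h5py and inverts the layout; the file name "output_000.h5" is only a placeholder for "{output_filename}_{nnn}" plus the HDF5 ground-truth extension.

import h5py

# placeholder chunk file name; the real one is built in the method above
with h5py.File("output_000.h5", "r") as f:
    codec = [chr(c) for c in f["codec"][...]]        # stored code points back to characters
    texts = ["".join(codec[i] for i in indices)      # transcripts are lists of codec indices
             for indices in f["transcripts"][...]]
    images = [flat.reshape(shape)                    # images are flattened uint8 arrays
              for flat, shape in zip(f["images"][...], f["images_dims"][...])]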
Example #2
def create_train_dataset(args, dataset_args=None):
    gt_extension = args.gt_extension if args.gt_extension is not None else DataSetType.gt_extension(args.dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        if gt_extension:
            gt_txt_files = [split_all_ext(f)[0] + gt_extension for f in input_image_files]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset
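
A hypothetical driver for the helper above; the attribute names mirror the accesses in create_train_dataset, while the values (and the choice of DataSetType.FILE) are only illustrative.

from argparse import Namespace

from calamari_ocr.ocr.datasets import DataSetType

args = Namespace(
    dataset=DataSetType.FILE,   # assumed plain image/text-file dataset type
    files=["lines/*.png"],      # image globs resolved by glob_all
    text_files=None,            # falsy -> ground truth derived from the image basenames
    gt_extension=None,          # None -> DataSetType.gt_extension(args.dataset)
    no_skip_invalid_gt=False,   # keep skipping invalid ground truth
)
train_dataset = create_train_dataset(args)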
Example #3
    def __init__(
        self,
        mode: DataSetMode,
        files: List[str] = None,
        xmlfiles: List[str] = None,
        skip_invalid=False,
        remove_invalid=True,
        binary=False,
        non_existing_as_empty=False,
    ):
        """ Create a dataset from a Path as String

        Parameters
         ----------
        files : [], required
            image files
        skip_invalid : bool, optional
            skip invalid files
        remove_invalid : bool, optional
            remove invalid files
        """

        super().__init__(mode, skip_invalid, remove_invalid)

        self.xmlfiles = xmlfiles if xmlfiles else []
        self.files = files if files else []

        self._non_existing_as_empty = non_existing_as_empty
        if len(self.xmlfiles) == 0:
            from calamari_ocr.ocr.datasets import DataSetType
            self.xmlfiles = [
                split_all_ext(p)[0] +
                DataSetType.gt_extension(DataSetType.ABBYY) for p in files
            ]

        if len(self.files) == 0:
            self.files = [None] * len(self.xmlfiles)

        self.book = XMLReader(self.files, self.xmlfiles, skip_invalid,
                              remove_invalid).read()
        self.binary = binary

        for p, page in enumerate(self.book.pages):
            for l, line in enumerate(page.getLines()):
                for f, fo in enumerate(line.formats):
                    self.add_sample({
                        "image_path": page.imgFile,
                        "xml_path": page.xmlFile,
                        "id": "{}_{}_{}_{}".format(
                            os.path.splitext(page.xmlFile if page.xmlFile else page.imgFile)[0],
                            p, l, f),
                        "line": line,
                        "format": fo,
                    })
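
Illustrative only: how this constructor might be called. The class name AbbyyDataSet stands in for whatever class this __init__ belongs to, and the DataSetMode import path is an assumption mirrored from the DataSetType import used above.

from calamari_ocr.ocr.datasets import DataSetMode

# AbbyyDataSet is a hypothetical name for the class defined above
dataset = AbbyyDataSet(
    DataSetMode.TRAIN,
    files=["book/page_0001.png"],  # hypothetical image path
    xmlfiles=None,                 # derived via DataSetType.gt_extension(DataSetType.ABBYY)
    skip_invalid=True,
)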
Example #4
def create_train_dataset(cfg: CfgNode, dataset_args=None):
    gt_extension = cfg.DATASET.TRAIN.GT_EXTENSION if cfg.DATASET.TRAIN.GT_EXTENSION is not False else DataSetType.gt_extension(
        cfg.DATASET.TRAIN.TYPE)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(cfg.DATASET.TRAIN.PATH))
    if not cfg.DATASET.TRAIN.TEXT_FILES:
        if gt_extension:
            gt_txt_files = [
                split_all_ext(f)[0] + gt_extension for f in input_image_files
            ]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(cfg.DATASET.TRAIN.TEXT_FILES))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        cfg.DATASET.TRAIN.TYPE,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset
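
A rough sketch of the yacs config fields this variant reads; the node names are mirrored from the attribute accesses above, and the values (including the string placeholder for the dataset type) are assumptions, not the project's actual defaults.

from yacs.config import CfgNode as CN

cfg = CN()
cfg.DATASET = CN()
cfg.DATASET.TRAIN = CN()
cfg.DATASET.TRAIN.TYPE = "FILE"            # placeholder; passed straight to create_dataset
cfg.DATASET.TRAIN.PATH = ["train/*.png"]   # image globs
cfg.DATASET.TRAIN.TEXT_FILES = []          # empty -> ground truth derived from image basenames
cfg.DATASET.TRAIN.GT_EXTENSION = False     # False -> DataSetType.gt_extension(TYPE)
cfg.DATALOADER = CN()
cfg.DATALOADER.NO_SKIP_INVALID_GT = False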
Example #5
def run(cfg: CfgNode):

    # check if loading a json file
    if len(cfg.DATASET.TRAIN.PATH) == 1 and cfg.DATASET.TRAIN.PATH[0].endswith(
            "json"):
        import json
        with open(cfg.DATASET.TRAIN.PATH[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(cfg, key, DataSetType.from_string(value))
                else:
                    setattr(cfg, key, value)

    # parse whitelist
    whitelist = cfg.MODEL.CODEX.WHITELIST
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(cfg.MODEL.CODEX.WHITELIST_FILES)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if cfg.DATASET.TRAIN.GT_EXTENSION is False:
        cfg.DATASET.TRAIN.GT_EXTENSION = DataSetType.gt_extension(
            cfg.DATASET.TRAIN.TYPE)

    if cfg.DATASET.VALID.GT_EXTENSION is False:
        cfg.DATASET.VALID.GT_EXTENSION = DataSetType.gt_extension(
            cfg.DATASET.VALID.TYPE)

    text_generator_params = TextGeneratorParameters()

    line_generator_params = LineGeneratorParameters()

    dataset_args = {
        'line_generator_params': line_generator_params,
        'text_generator_params': text_generator_params,
        'pad': None,
        'text_index': 0,
    }

    # Training dataset
    dataset = create_train_dataset(cfg, dataset_args)

    # Validation dataset
    validation_dataset_list = create_test_dataset(cfg, dataset_args)

    params = CheckpointParams()

    params.max_iters = cfg.SOLVER.MAX_ITER
    params.stats_size = cfg.STATS_SIZE
    params.batch_size = cfg.SOLVER.BATCH_SIZE
    params.checkpoint_frequency = cfg.SOLVER.CHECKPOINT_FREQ if cfg.SOLVER.CHECKPOINT_FREQ >= 0 else cfg.SOLVER.EARLY_STOPPING_FREQ
    params.output_dir = cfg.OUTPUT_DIR
    params.output_model_prefix = cfg.OUTPUT_MODEL_PREFIX
    params.display = cfg.DISPLAY
    params.skip_invalid_gt = not cfg.DATALOADER.NO_SKIP_INVALID_GT
    params.processes = cfg.NUM_THREADS
    params.data_aug_retrain_on_original = not cfg.DATALOADER.ONLY_TRAIN_ON_AUGMENTED

    params.early_stopping_at_acc = cfg.SOLVER.EARLY_STOPPING_AT_ACC
    params.early_stopping_frequency = cfg.SOLVER.EARLY_STOPPING_FREQ
    params.early_stopping_nbest = cfg.SOLVER.EARLY_STOPPING_NBEST
    params.early_stopping_best_model_prefix = cfg.EARLY_STOPPING_BEST_MODEL_PREFIX
    params.early_stopping_best_model_output_dir = \
        cfg.EARLY_STOPPING_BEST_MODEL_OUTPUT_DIR if cfg.EARLY_STOPPING_BEST_MODEL_OUTPUT_DIR else cfg.OUTPUT_DIR

    if cfg.INPUT.DATA_PREPROCESSING is False or len(
            cfg.INPUT.DATA_PREPROCESSING) == 0:
        cfg.INPUT.DATA_PREPROCESSING = [
            DataPreprocessorParams.DEFAULT_NORMALIZER
        ]

    params.model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in cfg.INPUT.DATA_PREPROCESSING:
        pp = params.model.data_preprocessor.children.add()
        pp.type = DataPreprocessorParams.Type.Value(preproc) if isinstance(
            preproc, str) else preproc
        pp.line_height = cfg.INPUT.LINE_HEIGHT
        pp.pad = cfg.INPUT.PAD

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(
        params.model.text_preprocessor.children.add(),
        default=cfg.INPUT.TEXT_NORMALIZATION)
    default_text_regularizer_params(
        params.model.text_preprocessor.children.add(),
        groups=cfg.INPUT.TEXT_REGULARIZATION)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(
        params.model.text_postprocessor.children.add(),
        default=cfg.INPUT.TEXT_NORMALIZATION)
    default_text_regularizer_params(
        params.model.text_postprocessor.children.add(),
        groups=cfg.INPUT.TEXT_REGULARIZATION)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if cfg.SEED > 0:
        params.model.network.backend.random_seed = cfg.SEED

    if cfg.INPUT.BIDI_DIR:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO
        }

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[
            cfg.INPUT.BIDI_DIR]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = cfg.INPUT.LINE_HEIGHT
    params.model.network.learning_rate = cfg.SOLVER.LR
    params.model.network.lr_decay = cfg.SOLVER.LR_DECAY
    params.model.network.lr_decay_freq = cfg.SOLVER.LR_DECAY_FREQ
    params.model.network.train_last_n_layer = cfg.SOLVER.TRAIN_LAST_N_LAYER
    network_params_from_definition_string(cfg.MODEL.NETWORK,
                                          params.model.network)
    params.model.network.clipping_norm = cfg.SOLVER.GRADIENT_CLIPPING_NORM
    params.model.network.backend.num_inter_threads = 0
    params.model.network.backend.num_intra_threads = 0
    params.model.network.backend.shuffle_buffer_size = cfg.DATALOADER.SHUFFLE_BUFFER_SIZE

    if cfg.MODEL.WEIGHTS == "":
        weights = None
    else:
        weights = cfg.MODEL.WEIGHTS

    # create the actual trainer
    trainer = Trainer(
        params,
        dataset,
        validation_dataset=validation_dataset_list,
        data_augmenter=SimpleDataAugmenter(),
        n_augmentations=cfg.INPUT.N_AUGMENT,
        weights=weights,
        codec_whitelist=whitelist,
        keep_loaded_codec=cfg.MODEL.CODEX.KEEP_LOADED_CODEC,
        preload_training=not cfg.DATALOADER.TRAIN_ON_THE_FLY,
        preload_validation=not cfg.DATALOADER.VALID_ON_THE_FLY,
    )
    trainer.train(auto_compute_codec=not cfg.MODEL.CODEX.SEE_WHITELIST,
                  progress_bar=not cfg.NO_PROGRESS_BAR)
Example #6
def run(args):

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    if args.text_generator_params is not None:
        with open(args.text_generator_params, 'r') as f:
            args.text_generator_params = json_format.Parse(f.read(), TextGeneratorParameters())
    else:
        args.text_generator_params = TextGeneratorParameters()

    if args.line_generator_params is not None:
        with open(args.line_generator_params, 'r') as f:
            args.line_generator_params = json_format.Parse(f.read(), LineGeneratorParameters())
    else:
        args.line_generator_params = LineGeneratorParameters()

    dataset_args = {
        'line_generator_params': args.line_generator_params,
        'text_generator_params': args.text_generator_params,
        'pad': args.dataset_pad,
        'text_index': args.pagexml_text_index,
    }

    # Training dataset
    dataset = create_train_dataset(args, dataset_args)

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt,
            args=dataset_args,
        )
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    if args.data_preprocessing is None or len(args.data_preprocessing) == 0:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    params.model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        pp = params.model.data_preprocessor.children.add()
        pp.type = DataPreprocessorParams.Type.Value(preproc) if isinstance(preproc, str) else preproc
        pp.line_height = args.line_height
        pp.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads
    params.model.network.backend.shuffle_buffer_size = args.shuffle_buffer_size

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      keep_loaded_codec=args.keep_loaded_codec,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )
Example #7
def run(args):
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed as extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    args.extension = args.extension if args.extension else DataSetType.pred_extension(
        args.dataset)

    # create ctc decoder
    ctc_decoder_params = create_ctc_decoder_params(args)

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them; there won't be predictions for invalid files
    dataset = create_dataset(
        args.dataset,
        DataSetMode.PREDICT,
        input_image_files,
        args.text_files,
        skip_invalid=True,
        remove_invalid=True,
        args={
            'text_index': args.pagexml_text_index,
            'pad': args.dataset_pad,
        },
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes,
                               ctc_decoder_params=ctc_decoder_params)
    do_prediction = predictor.predict_dataset(
        dataset, progress_bar=not args.no_progress_bars)

    avg_sentence_confidence = 0
    n_predictions = 0

    dataset.prepare_store()

    # output the voted results to the appropriate files
    for result, sample in do_prediction:
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # wrap the sentence in an LRE/RLE embedding mark matching its bidi base level
            # and close it with PDF (U+202C) so the surrounding output keeps its direction
            lr = "\u202A\u202B"
            print("{}: '{}{}{}'".format(sample['id'],
                                        lr[get_base_level(sentence)], sentence,
                                        "\u202C"))

        output_dir = args.output_dir

        dataset.store_text(sentence,
                           sample,
                           output_dir=output_dir,
                           extension=args.extension)

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample[
                'image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(
                ps.line_path)
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)

            if args.extended_prediction_data_format == "pred":
                data = ps.SerializeToString()
            elif args.extended_prediction_data_format == "json":
                # remove logits
                for prediction in ps.predictions:
                    prediction.logits.rows = 0
                    prediction.logits.cols = 0
                    prediction.logits.data[:] = []

                data = MessageToJson(ps, including_default_value_fields=True)
            else:
                raise Exception("Unknown prediction format.")

            dataset.store_extended_prediction(
                data,
                sample,
                output_dir=output_dir,
                extension=args.extended_prediction_data_format)

    print("Average sentence confidence: {:.2%}".format(
        avg_sentence_confidence / n_predictions))

    dataset.store(args.extension)
    print("All files written")
Example #8
File: train.py, Project: AIRob/calamari
def run(args):

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )