コード例 #1
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 backend=None):
        """Set up the instance from a stored checkpoint or from a live backend.

        Exactly one of ``checkpoint`` and ``backend`` is expected.

        Args:
            checkpoint: Checkpoint path without the ``.json`` suffix. The
                parameters are read from ``checkpoint + '.json'`` and the same
                path is handed to the backend factory as ``restore``.
            text_postproc: Text post-processor. In checkpoint mode a falsy
                value is replaced by one built from the stored model params.
            data_preproc: Data pre-processor, with the same fallback rule.
            codec: Stored on the instance unchanged.
            backend: Existing backend object exposing ``network_proto``.

        Raises:
            Exception: If both a checkpoint and a backend are given, or neither.
        """
        self.backend = backend
        self.checkpoint = checkpoint
        self.codec = codec

        # Guard: at least one source of a network is required.
        if not checkpoint and not backend:
            raise Exception(
                "Either a checkpoint or a existing backend must be provided")

        if not checkpoint:
            # Backend-only mode: nothing is loaded, processors are used as given.
            self.model_params = None
            self.network_params = backend.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            return

        # Checkpoint mode from here on; an additional backend is rejected.
        if backend:
            raise Exception(
                "Either a checkpoint or a backend can be provided")

        with open(checkpoint + '.json', 'r') as f:
            raw_json = f.read()
        stored_params = json_format.Parse(raw_json, CheckpointParams())
        self.model_params = stored_params.model

        self.network_params = self.model_params.network
        self.backend = create_backend_from_proto(self.network_params,
                                                 restore=self.checkpoint)

        # Explicitly passed processors win over the ones stored in the model.
        self.text_postproc = text_postproc or text_processor_from_proto(
            self.model_params.text_postprocessor, "post")
        self.data_preproc = data_preproc or data_processor_from_proto(
            self.model_params.data_preprocessor)
コード例 #2
0
ファイル: checkpoint.py プロジェクト: maxwehner/Calamari-OCR
    def __init__(self, json_path: str, auto_update=True, dry_run=False):
        """Load a checkpoint description and bring it to the current version.

        Args:
            json_path: Path to the checkpoint JSON; ``.json`` is appended when
                missing. Environment variables and ``~`` are expanded and the
                result is made absolute.
            auto_update: When the stored version differs from
                ``Checkpoint.VERSION``, call ``self.update_checkpoint()``
                instead of raising.
            dry_run: Stored on the instance as ``self.dry_run``.

        Raises:
            Exception: If the version differs and ``auto_update`` is falsy.
        """
        if not json_path.endswith('.json'):
            json_path = json_path + '.json'
        expanded = os.path.expanduser(os.path.expandvars(json_path))
        self.json_path = os.path.abspath(expanded)
        self.ckpt_path, _ = os.path.splitext(self.json_path)
        self.dry_run = dry_run

        # do not parse as proto, since some parameters might have changed
        with open(self.json_path, 'r') as f:
            self.json = json.load(f)

        # A file without a version entry is treated as version 0.
        if 'version' in self.json:
            self.version = self.json['version']
        else:
            self.version = 0

        if self.version == Checkpoint.VERSION:
            print("Checkpoint version {} is up-to-date.".format(self.version))
        elif auto_update:
            self.update_checkpoint()
        else:
            raise Exception(
                "Version of checkpoint is {} but {} is required. Please upgrade the model or "
                "set the auto update flag.".format(self.version,
                                                   Checkpoint.VERSION))

        # Read the file again, this time as a proto message.
        with open(self.json_path, 'r') as f:
            self.checkpoint = json_format.Parse(f.read(), CheckpointParams())
コード例 #3
0
ファイル: resume_training.py プロジェクト: zzmcdc/calamari
def main():
    """Command-line entry point: resume a training run from a checkpoint.

    Parses ``--checkpoint`` (required), ``--validation`` (optional) and the
    positional training ``files``, builds the training and optional validation
    ``FileDataSet``, reads ``<checkpoint>.json`` as ``CheckpointParams`` and
    continues training with weights restored from ``--checkpoint``.

    Raises:
        Exception: If the same ground-truth file is derived from more than one
            training image, or from more than one validation image.
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="The checkpoint used to resume")
    parser.add_argument("--validation",
                        type=str,
                        nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("files",
                        type=str,
                        nargs="+",
                        help="The files to use for training")

    args = parser.parse_args()

    # Train dataset
    input_image_files = glob_all(args.files)
    gt_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in input_image_files]

    # Two images mapping to the same .gt.txt means a duplicated sample.
    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = FileDataSet(input_image_files, gt_txt_files)

    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        validation_image_files = glob_all(args.validation)
        val_txt_files = [
            split_all_ext(f)[0] + ".gt.txt" for f in validation_image_files
        ]

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception(
                "Some validation images are occurring more than once in the data set."
            )

        validation_dataset = FileDataSet(validation_image_files, val_txt_files)
        print("Found {} files in the validation dataset".format(
            len(validation_dataset)))
    else:
        validation_dataset = None

    # Only keep the JSON file open while it is being read; training can run
    # for a long time and does not need the handle.
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params,
                      dataset,
                      validation_dataset=validation_dataset,
                      restore=args.checkpoint)
    trainer.train(progress_bar=True)
コード例 #4
0
def params_from_args(args):
    """
    Build a CheckpointParams message from the parsed calamari arguments.

    Plain training options are copied over, the data preprocessor and the
    text pre-/post-processor chains are configured, and the network section
    is filled from the network definition string in ``args.network``.
    """
    params = CheckpointParams()
    model = params.model

    # Options whose name is identical on ``args`` and on the message.
    same_named_options = ("max_iters", "stats_size", "batch_size", "checkpoint_frequency", "output_dir",
                          "output_model_prefix", "display", "early_stopping_nbest",
                          "early_stopping_best_model_prefix")
    for option in same_named_options:
        setattr(params, option, getattr(args, option))

    params.processes = args.num_threads
    params.skip_invalid_gt = not args.no_skip_invalid_gt

    # A negative early stopping frequency falls back to the checkpoint frequency.
    if args.early_stopping_frequency >= 0:
        params.early_stopping_frequency = args.early_stopping_frequency
    else:
        params.early_stopping_frequency = args.checkpoint_frequency
    params.early_stopping_best_model_output_dir = (
        args.early_stopping_best_model_output_dir or args.output_dir)

    data_preproc = model.data_preprocessor
    data_preproc.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    data_preproc.line_height = args.line_height
    data_preproc.pad = args.pad

    # Text pre processing (reading) and text post processing (prediction)
    # share the same chain: normalizer, regularizer, strip.
    for text_proc in (model.text_preprocessor, model.text_postprocessor):
        text_proc.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(text_proc.children.add(), default=args.text_normalization)
        default_text_regularizer_params(text_proc.children.add(), groups=args.text_regularization)
        text_proc.children.add().type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        direction_by_name = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO,
        }

        pre_bidi = model.text_preprocessor.children.add()
        pre_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        pre_bidi.bidi_direction = direction_by_name[args.bidi_dir]

        # The post-processor always uses the automatic direction.
        post_bidi = model.text_postprocessor.children.add()
        post_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        post_bidi.bidi_direction = TextProcessorParams.BIDI_AUTO

    model.line_height = args.line_height

    network = model.network
    network_params_from_definition_string(args.network, network)
    clipping_name = "CLIP_" + args.gradient_clipping_mode.upper()
    network.clipping_mode = NetworkParams.ClippingMode.Value(clipping_name)
    network.clipping_constant = args.gradient_clipping_const

    backend = network.backend
    backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    backend.num_inter_threads = args.num_inter_threads
    backend.num_intra_threads = args.num_intra_threads
    return params
コード例 #5
0
ファイル: trainer.py プロジェクト: templeblock/calamari
    def train(self, progress_bar=False):
        """Run the training loop described by ``self.checkpoint_params``.

        Loads and preprocesses the training (and optional validation) samples,
        determines the codec, creates the backend and iterates from the stored
        iteration up to ``max_iters``. A checkpoint is written every
        ``checkpoint_frequency`` iterations. With a validation dataset, the
        model is evaluated every ``early_stopping_frequency`` iterations and
        the best one is stored separately; training stops early once
        ``early_stopping_nbest`` evaluations in a row brought no improvement.

        Args:
            progress_bar: Forwarded to the sample loading, preprocessing,
                augmentation, prediction and evaluation helpers.

        Raises:
            Exception: If the training data or the validation data is empty,
                if the model to restore was trained with a different line
                height, or if the loss becomes non-finite before any
                checkpoint has been written.
            KeyboardInterrupt: Re-raised after an "interrupted" checkpoint has
                been stored.
        """
        checkpoint_params = self.checkpoint_params

        # Shift the start time back by the time already spent in earlier runs,
        # so that ``time.time() - train_start_time`` is the accumulated total.
        train_start_time = time.time() - self.checkpoint_params.total_time

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
        else:
            validation_datas, validation_txts = [], []

        # preprocessing steps
        texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                             processes=checkpoint_params.processes, progress_bar=progress_bar)

            # TODO: validation data augmentation

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as-well
            with open(self.weights + '.json', 'r') as f:
                restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
                restore_model_params = restore_checkpoint_params.model

            # checks: compare against the *restored* model. ``features`` was
            # just assigned from the requested line height above, so comparing
            # those two values with each other could never detect a mismatch.
            if restore_model_params.line_height != network_params.features:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        backend.set_train_data(datas, labels)
        backend.set_prediction_data(validation_datas)
        if codec_changes:
            backend.realign_model_labels(*codec_changes)
        backend.prepare(train=True)

        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             backend=backend)

        # Start the actual training
        # ====================================================================================

        # NOTE: ``make_checkpoint`` below reads ``iter`` and the early stopping
        # variables from this scope at call time, so the name must stay the
        # one that the training loop rebinds.
        iter = checkpoint_params.iter

        # helper function to write a checkpoint
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            backend.save_checkpoint(checkpoint_path)
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = backend.train_step(checkpoint_params.batch_size)

                if not np.isfinite(result['loss']):
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                        continue

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if iter % checkpoint_params.display == 0:
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    print(" PRED: '{}'".format(pred_sentence))
                    print(" TRUE: '{}'".format(gt_sentence))

                if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                               progress_bar=progress_bar, apply_preproc=False)
                    pred_texts = [d.sentence for d in out]
                    result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

        except KeyboardInterrupt:
            # Keep the current state so that the run can be resumed later.
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
コード例 #6
0
def main():
    """Command-line entry point: evaluate predictions against ground truth.

    Resolves the ground truth and prediction files, optionally loads the text
    preprocessor of a checkpoint, runs the ``Evaluator`` and prints the mean
    normalized label error rate, the most common confusions and the worst
    lines. With ``--xlsx_output`` the results are also written to an xlsx file.

    ``--non_existing_file_handling_mode`` selects what happens to missing
    prediction files: "skip" drops the gt/pred pair, "empty" treats the file
    as empty and "error" (default) lets the dataset fail on it.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--gt",
        nargs="+",
        required=True,
        help="Ground truth files (.gt.txt extension). "
        "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        type=str,
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as would it be empty (fully checking for errors)."
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--skip_empty_gt",
                        action="store_true",
                        default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    # page xml specific args
    parser.add_argument("--pagexml_gt_text_index", default=0)
    parser.add_argument("--pagexml_pred_text_index", default=1)

    args = parser.parse_args()

    # check if loading a json file
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        # Derive the prediction file names from the ground truth files.
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    handling_mode = args.non_existing_file_handling_mode.lower()

    if handling_mode == "skip":
        # Drop every gt/pred pair whose prediction file is missing. Both lists
        # are filtered by position so that they stay aligned.
        missing = {
            idx for idx, pred in enumerate(pred_files)
            if not os.path.exists(pred)
        }
        pred_files = [p for idx, p in enumerate(pred_files) if idx not in missing]
        gt_files = [g for idx, g in enumerate(gt_files) if idx not in missing]

    text_preproc = None
    if args.checkpoint:
        with open(
                args.checkpoint if args.checkpoint.endswith(".json") else
                args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(
                checkpoint_params.model.text_preprocessor)

    # Only the "error" mode must fail on missing files. The comparison string
    # must not carry stray whitespace, otherwise this is always True and the
    # default mode silently behaves like "empty".
    non_existing_as_empty = handling_mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc,
                          skip_empty_gt=args.skip_empty_gt)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
        }])
コード例 #7
0
def run(cfg: CfgNode):
    """Translate ``cfg`` into ``CheckpointParams`` and start a training run.

    If the single training path ends in "json", its entries are first copied
    onto ``cfg``. Afterwards the codec whitelist is collected, the training
    and validation datasets are created, the checkpoint parameters are filled
    from the config and a ``Trainer`` is created and started.
    """
    # check if loading a json file
    train_paths = cfg.DATASET.TRAIN.PATH
    if len(train_paths) == 1 and train_paths[0].endswith("json"):
        import json
        with open(train_paths[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            # Dataset types are stored as strings and need converting.
            if key in ('dataset', 'validation_dataset'):
                value = DataSetType.from_string(value)
            setattr(cfg, key, value)

    # parse whitelist
    whitelist = cfg.MODEL.CODEX.WHITELIST
    if len(whitelist) == 1:
        # A single entry is split into its individual characters.
        whitelist = list(whitelist[0])

    for whitelist_file in glob_all(cfg.MODEL.CODEX.WHITELIST_FILES):
        with open(whitelist_file) as txt:
            whitelist += list(txt.read())

    # Fill in the default ground truth extension of each dataset type.
    for split_cfg in (cfg.DATASET.TRAIN, cfg.DATASET.VALID):
        if split_cfg.GT_EXTENSION is False:
            split_cfg.GT_EXTENSION = DataSetType.gt_extension(split_cfg.TYPE)

    text_gen_params = TextGeneratorParameters()
    line_gen_params = LineGeneratorParameters()

    dataset_args = {
        'line_generator_params': line_gen_params,
        'text_generator_params': text_gen_params,
        'pad': None,
        'text_index': 0,
    }

    # Training dataset
    dataset = create_train_dataset(cfg, dataset_args)

    # Validation dataset
    validation_dataset_list = create_test_dataset(cfg, dataset_args)

    solver = cfg.SOLVER
    loader = cfg.DATALOADER
    input_cfg = cfg.INPUT

    params = CheckpointParams()
    model = params.model

    params.max_iters = solver.MAX_ITER
    params.stats_size = cfg.STATS_SIZE
    params.batch_size = solver.BATCH_SIZE
    # A negative checkpoint frequency falls back to the early stopping frequency.
    if solver.CHECKPOINT_FREQ >= 0:
        params.checkpoint_frequency = solver.CHECKPOINT_FREQ
    else:
        params.checkpoint_frequency = solver.EARLY_STOPPING_FREQ
    params.output_dir = cfg.OUTPUT_DIR
    params.output_model_prefix = cfg.OUTPUT_MODEL_PREFIX
    params.display = cfg.DISPLAY
    params.skip_invalid_gt = not loader.NO_SKIP_INVALID_GT
    params.processes = cfg.NUM_THREADS
    params.data_aug_retrain_on_original = not loader.ONLY_TRAIN_ON_AUGMENTED

    params.early_stopping_at_acc = solver.EARLY_STOPPING_AT_ACC
    params.early_stopping_frequency = solver.EARLY_STOPPING_FREQ
    params.early_stopping_nbest = solver.EARLY_STOPPING_NBEST
    params.early_stopping_best_model_prefix = cfg.EARLY_STOPPING_BEST_MODEL_PREFIX
    if cfg.EARLY_STOPPING_BEST_MODEL_OUTPUT_DIR:
        params.early_stopping_best_model_output_dir = cfg.EARLY_STOPPING_BEST_MODEL_OUTPUT_DIR
    else:
        params.early_stopping_best_model_output_dir = cfg.OUTPUT_DIR

    # Without configured data preprocessors the default normalizer is used.
    if input_cfg.DATA_PREPROCESSING is False or len(
            input_cfg.DATA_PREPROCESSING) == 0:
        input_cfg.DATA_PREPROCESSING = [
            DataPreprocessorParams.DEFAULT_NORMALIZER
        ]

    model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in input_cfg.DATA_PREPROCESSING:
        child = model.data_preprocessor.children.add()
        # Entries may be given by name or directly as enum value.
        if isinstance(preproc, str):
            child.type = DataPreprocessorParams.Type.Value(preproc)
        else:
            child.type = preproc
        child.line_height = input_cfg.LINE_HEIGHT
        child.pad = input_cfg.PAD

    # Text pre processing (reading) and text post processing (prediction)
    # share the same chain: normalizer, regularizer, strip.
    for text_proc in (model.text_preprocessor, model.text_postprocessor):
        text_proc.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(
            text_proc.children.add(),
            default=input_cfg.TEXT_NORMALIZATION)
        default_text_regularizer_params(
            text_proc.children.add(),
            groups=input_cfg.TEXT_REGULARIZATION)
        text_proc.children.add().type = TextProcessorParams.STRIP_NORMALIZER

    if cfg.SEED > 0:
        model.network.backend.random_seed = cfg.SEED

    if input_cfg.BIDI_DIR:
        # change bidirectional text direction if desired
        direction_by_name = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO
        }

        pre_bidi = model.text_preprocessor.children.add()
        pre_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        pre_bidi.bidi_direction = direction_by_name[input_cfg.BIDI_DIR]

        # The post-processor always uses the automatic direction.
        post_bidi = model.text_postprocessor.children.add()
        post_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        post_bidi.bidi_direction = TextProcessorParams.BIDI_AUTO

    model.line_height = input_cfg.LINE_HEIGHT

    network = model.network
    network.learning_rate = solver.LR
    network.lr_decay = solver.LR_DECAY
    network.lr_decay_freq = solver.LR_DECAY_FREQ
    network.train_last_n_layer = solver.TRAIN_LAST_N_LAYER
    network_params_from_definition_string(cfg.MODEL.NETWORK, network)
    network.clipping_norm = solver.GRADIENT_CLIPPING_NORM
    network.backend.num_inter_threads = 0
    network.backend.num_intra_threads = 0
    network.backend.shuffle_buffer_size = loader.SHUFFLE_BUFFER_SIZE

    # An empty string means that no pretrained weights are loaded.
    weights = None if cfg.MODEL.WEIGHTS == "" else cfg.MODEL.WEIGHTS

    # create the actual trainer
    trainer = Trainer(
        params,
        dataset,
        validation_dataset=validation_dataset_list,
        data_augmenter=SimpleDataAugmenter(),
        n_augmentations=cfg.INPUT.N_AUGMENT,
        weights=weights,
        codec_whitelist=whitelist,
        keep_loaded_codec=cfg.MODEL.CODEX.KEEP_LOADED_CODEC,
        preload_training=not loader.TRAIN_ON_THE_FLY,
        preload_validation=not loader.VALID_ON_THE_FLY,
    )
    trainer.train(auto_compute_codec=not cfg.MODEL.CODEX.SEE_WHITELIST,
                  progress_bar=not cfg.NO_PROGRESS_BAR)
コード例 #8
0
ファイル: nashi_client.py プロジェクト: stweil/nashi
def params_from_args(args):
    """
    Turn args to calamari into params

    Copies the training, early stopping, pre-/post-processing and network
    options found on ``args`` into a newly created ``CheckpointParams``.

    :param args: object providing the options read below as attributes
        (``max_iters``, ``batch_size``, ``network``, ``bidi_dir``, ...)
    :return: the populated ``CheckpointParams``
    :raises KeyError: if ``args.bidi_dir`` is set to something other than
        "rtl", "ltr" or "auto"
    """
    params = CheckpointParams()
    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    # a negative checkpoint frequency falls back to the early stopping frequency
    if args.checkpoint_frequency >= 0:
        params.checkpoint_frequency = args.checkpoint_frequency
    else:
        params.checkpoint_frequency = args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    # early stopping (the accuracy threshold is set once, here)
    params.early_stopping_at_acc = args.early_stopping_at_accuracy
    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    # the best model is written to the regular output dir unless stated otherwise
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading) and text post processing (prediction)
    # receive the same chain: normalizer, regularizer, strip (in that order)
    for text_processor in (params.model.text_preprocessor,
                           params.model.text_postprocessor):
        text_processor.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(
            text_processor.children.add(),
            default=args.text_normalization)
        default_text_regularizer_params(
            text_processor.children.add(),
            groups=args.text_regularization)
        strip_processor_params = text_processor.children.add()
        strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # the seed is only forwarded when it is strictly positive
    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO
        }

        # the pre processor uses the requested direction ...
        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        # ... while the post processor is always set to BIDI_AUTO
        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_norm = args.gradient_clipping_norm
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads
    params.model.network.backend.shuffle_buffer_size = args.shuffle_buffer_size

    return params
コード例 #9
0
def main():
    """
    Command line entry point that resumes the training of a checkpoint.

    Resolves the training and (optional) validation files given on the
    command line, parses ``<checkpoint>.json`` into ``CheckpointParams`` and
    runs a ``Trainer`` that receives the checkpoint as ``weights``.

    :raises Exception: if an image and its ground truth file have different
        base names, or if a ground truth file occurs more than once
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation",
                        type=str,
                        nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument(
        "--validation_text_files",
        nargs="+",
        default=None,
        help=
        "Optional list of validation GT files if they are in other directory")
    parser.add_argument(
        "--validation_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--validation_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)

    # input files
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location"
    )
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # derive the ground truth extensions from the dataset types if not given
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        # ground truth is expected next to each image
        gt_txt_files = [
            split_all_ext(f)[0] + args.gt_extension for f in input_image_files
        ]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        # make sure every image is paired with the ground truth of the same name
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = create_dataset(args.dataset,
                             DataSetMode.TRAIN,
                             images=input_image_files,
                             texts=gt_txt_files,
                             skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [
                split_all_ext(f)[0] + args.validation_extension
                for f in validation_image_files
            ]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                        os.path.basename(gt))[0]:
                    raise Exception(
                        "Expected identical basenames of validation file: {} and {}"
                        .format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception(
                "Some validation images are occurring more than once in the data set."
            )

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(
            len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    # Keep the json file open only while it is parsed: the training below does
    # not read from the handle, so there is no reason to hold it until the
    # training has finished.
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params,
                      dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)
コード例 #10
0
ファイル: train.py プロジェクト: subrahmanyap/calamari
def run(args):
    """
    Set up and run a training as described by ``args``.

    If ``args.files`` consists of a single path ending in "json", the entries
    of that json file are written onto ``args`` first. Afterwards the
    datasets are created, a ``CheckpointParams`` is filled from ``args`` and
    a ``Trainer`` is started.

    :param args: training options as attributes (modified in place)
    :raises Exception: if a validation image and its ground truth file have
        different base names, or if a validation ground truth file occurs
        more than once
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as json_file:
            loaded_args = json.load(json_file)

        for name, loaded in loaded_args.items():
            # dataset type entries are converted with DataSetType.from_string
            if name in ('dataset', 'validation_dataset'):
                loaded = DataSetType.from_string(loaded)
            setattr(args, name, loaded)

    # parse whitelist: a single entry is split into its characters
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    # every character of every whitelist file is appended
    for whitelist_path in glob_all(args.whitelist_files):
        with open(whitelist_path) as handle:
            whitelist += list(handle.read())

    # fall back to the extension reported for the dataset type
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # generator parameters: a given path is parsed as json, otherwise defaults
    if args.text_generator_params is None:
        args.text_generator_params = TextGeneratorParameters()
    else:
        with open(args.text_generator_params, 'r') as handle:
            args.text_generator_params = json_format.Parse(handle.read(), TextGeneratorParameters())

    if args.line_generator_params is None:
        args.line_generator_params = LineGeneratorParameters()
    else:
        with open(args.line_generator_params, 'r') as handle:
            args.line_generator_params = json_format.Parse(handle.read(), LineGeneratorParameters())

    dataset_args = dict(
        line_generator_params=args.line_generator_params,
        text_generator_params=args.text_generator_params,
        pad=args.dataset_pad,
        text_index=args.pagexml_text_index,
    )

    # Training dataset
    dataset = create_train_dataset(args, dataset_args)

    # Validation dataset
    validation_dataset = None
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if args.validation_text_files:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            # each image must be paired with a ground truth of the same base name
            for img, gt in zip(validation_image_files, val_txt_files):
                img_stem = split_all_ext(os.path.basename(img))[0]
                gt_stem = split_all_ext(os.path.basename(gt))[0]
                if img_stem != gt_stem:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))
        else:
            val_txt_files = [
                split_all_ext(image_path)[0] + args.validation_extension
                for image_path in validation_image_files
            ]

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt,
            args=dataset_args,
        )
        print("Found {} files in the validation dataset".format(len(validation_dataset)))

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    # a negative checkpoint frequency falls back to the early stopping frequency
    if args.checkpoint_frequency >= 0:
        params.checkpoint_frequency = args.checkpoint_frequency
    else:
        params.checkpoint_frequency = args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir or args.output_dir

    # without explicit data preprocessors the default normalizer is used
    if not args.data_preprocessing:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    params.model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        child = params.model.data_preprocessor.children.add()
        # strings are looked up by name, anything else is assigned as is
        if isinstance(preproc, str):
            child.type = DataPreprocessorParams.Type.Value(preproc)
        else:
            child.type = preproc
        child.line_height = args.line_height
        child.pad = args.pad

    # Text pre processing (reading) first, then text post processing
    # (prediction); both get normalizer, regularizer and strip, in this order
    for text_processor in (params.model.text_preprocessor, params.model.text_postprocessor):
        text_processor.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(text_processor.children.add(), default=args.text_normalization)
        default_text_regularizer_params(text_processor.children.add(), groups=args.text_regularization)
        text_processor.children.add().type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO,
        }

        pre_bidi = params.model.text_preprocessor.children.add()
        pre_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        pre_bidi.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        # the post processor is always set to BIDI_AUTO
        post_bidi = params.model.text_postprocessor.children.add()
        post_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        post_bidi.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    clipping_mode_name = "CLIP_" + args.gradient_clipping_mode.upper()
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value(clipping_mode_name)
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads
    params.model.network.backend.shuffle_buffer_size = args.shuffle_buffer_size

    # create the actual trainer
    trainer = Trainer(
        params,
        dataset,
        validation_dataset=validation_dataset,
        data_augmenter=SimpleDataAugmenter(),
        n_augmentations=args.n_augmentations,
        weights=args.weights,
        codec_whitelist=whitelist,
        keep_loaded_codec=args.keep_loaded_codec,
        preload_training=not args.train_data_on_the_fly,
        preload_validation=not args.validation_data_on_the_fly,
    )
    trainer.train(auto_compute_codec=not args.no_auto_compute_codec,
                  progress_bar=not args.no_progress_bars)
コード例 #11
0
def main():
    """
    Command line entry point that evaluates predictions against ground truth.

    Resolves the ground truth and prediction text files, optionally reads the
    text preprocessor from a checkpoint, runs the ``Evaluator`` and prints the
    result; an xlsx report is written if ``--xlsx_output`` is given.

    :raises Exception: if ``--pred`` is given and the number of prediction
        files differs from the number of ground truth files
    """
    parser = ArgumentParser()
    parser.add_argument("--gt",
                        nargs="+",
                        required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        type=str,
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as would it be empty (fully checking for errors). "
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
        if len(pred_files) != len(gt_files):
            raise Exception(
                "Mismatch in the number of gt and pred files: {} vs {}".format(
                    len(gt_files), len(pred_files)))
    else:
        # predictions are expected next to the ground truth files
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]

    # the mode is compared case-insensitively
    handling_mode = args.non_existing_file_handling_mode.lower()

    if handling_mode == "skip":
        # Drop every gt/pred pair whose prediction file is missing. Both lists
        # have the same length here, so filtering the zipped pairs in one pass
        # keeps them aligned and preserves their order.
        existing_pairs = [(gt, pred)
                          for gt, pred in zip(gt_files, pred_files)
                          if os.path.exists(pred)]
        gt_files = [gt for gt, _ in existing_pairs]
        pred_files = [pred for _, pred in existing_pairs]

    text_preproc = None
    if args.checkpoint:
        # the checkpoint may be given with or without its '.json' suffix
        with open(
                args.checkpoint if args.checkpoint.endswith(".json") else
                args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(
                checkpoint_params.model.text_preprocessor)

    non_existing_as_empty = handling_mode == "empty"
    gt_data_set = FileDataSet(texts=gt_files,
                              non_existing_as_empty=non_existing_as_empty)
    pred_data_set = FileDataSet(texts=pred_files,
                                non_existing_as_empty=non_existing_as_empty)

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_files, gt_data_set.text_samples(),
                      pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
            "gts": gt_data_set.text_samples(),
            "preds": pred_data_set.text_samples()
        }])
コード例 #12
0
ファイル: train.py プロジェクト: AIRob/calamari
def run(args):
    """
    Set up and run a training as described by ``args``.

    If ``args.files`` consists of a single path ending in "json", the entries
    of that json file are written onto ``args`` first. Afterwards the
    training and optional validation datasets are resolved, a
    ``CheckpointParams`` is filled from ``args`` and a ``Trainer`` is started.

    :param args: training options as attributes (modified in place)
    :raises Exception: if an image and its ground truth file have different
        base names, or if a ground truth file occurs more than once
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as json_file:
            loaded_args = json.load(json_file)

        for name, loaded in loaded_args.items():
            setattr(args, name, loaded)

    # parse whitelist: a single entry is split into its characters
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    # every character of every whitelist file is appended
    for whitelist_path in glob_all(args.whitelist_files):
        with open(whitelist_path) as handle:
            whitelist += list(handle.read())

    # fall back to the extension reported for the dataset type
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if args.text_files:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        # each image must be paired with a ground truth of the same base name
        for img, gt in zip(input_image_files, gt_txt_files):
            img_stem = split_all_ext(os.path.basename(img))[0]
            gt_stem = split_all_ext(os.path.basename(gt))[0]
            if img_stem != gt_stem:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))
    else:
        # ground truth is expected next to each image
        gt_txt_files = [
            split_all_ext(image_path)[0] + args.gt_extension
            for image_path in input_image_files
        ]

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(args.dataset,
                             DataSetMode.TRAIN,
                             images=input_image_files,
                             texts=gt_txt_files,
                             skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    validation_dataset = None
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if args.validation_text_files:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                img_stem = split_all_ext(os.path.basename(img))[0]
                gt_stem = split_all_ext(os.path.basename(gt))[0]
                if img_stem != gt_stem:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))
        else:
            val_txt_files = [
                split_all_ext(image_path)[0] + args.validation_extension
                for image_path in validation_image_files
            ]

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(args.validation_dataset,
                                            DataSetMode.TRAIN,
                                            images=validation_image_files,
                                            texts=val_txt_files,
                                            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    # a negative checkpoint frequency falls back to the early stopping frequency
    if args.checkpoint_frequency >= 0:
        params.checkpoint_frequency = args.checkpoint_frequency
    else:
        params.checkpoint_frequency = args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir or args.output_dir

    data_preprocessor = params.model.data_preprocessor
    data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    data_preprocessor.line_height = args.line_height
    data_preprocessor.pad = args.pad

    # Text pre processing (reading) first, then text post processing
    # (prediction); both get normalizer, regularizer and strip, in this order
    for text_processor in (params.model.text_preprocessor, params.model.text_postprocessor):
        text_processor.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(text_processor.children.add(), default=args.text_normalization)
        default_text_regularizer_params(text_processor.children.add(), groups=args.text_regularization)
        text_processor.children.add().type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO,
        }

        pre_bidi = params.model.text_preprocessor.children.add()
        pre_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        pre_bidi.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        # the post processor is always set to BIDI_AUTO
        post_bidi = params.model.text_postprocessor.children.add()
        post_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        post_bidi.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    clipping_mode_name = "CLIP_" + args.gradient_clipping_mode.upper()
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value(clipping_mode_name)
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(
        params,
        dataset,
        validation_dataset=validation_dataset,
        data_augmenter=SimpleDataAugmenter(),
        n_augmentations=args.n_augmentations,
        weights=args.weights,
        codec_whitelist=whitelist,
        preload_training=not args.train_data_on_the_fly,
        preload_validation=not args.validation_data_on_the_fly,
    )
    trainer.train(auto_compute_codec=not args.no_auto_compute_codec,
                  progress_bar=not args.no_progress_bars)
コード例 #13
0
def main():
    """Parse the training arguments, build the checkpoint parameters and train.

    Steps, in order:

    1. Parse the command line. If exactly one positional file is given and its
       name ends with ``json``, it is loaded and each of its top-level keys
       overrides the attribute of the same name on the parsed arguments.
    2. Collect the codec whitelist from ``args.whitelist`` plus every
       character contained in the files matched by ``args.whitelist_files``.
    3. Resolve the training (and, if requested, validation) images and pair
       each one with a ``.gt.txt`` ground truth file.
    4. Fill a ``CheckpointParams`` message from the arguments.
    5. Create a ``Trainer`` and run it.

    Raises
    ------
    Exception
        If two training images, or two validation images, map to the same
        ground truth file.
    """
    parser = argparse.ArgumentParser()
    setup_train_args(parser)
    args = parser.parse_args()

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            setattr(args, key, value)

    # parse whitelist
    # Work on a copy: extending ``args.whitelist`` in place would also modify
    # the list object owned by argparse (e.g. a shared ``default=[]``).
    whitelist = list(args.whitelist or [])
    for whitelist_file in glob_all(args.whitelist_files):
        with open(whitelist_file) as txt:
            # every single character of the file becomes a whitelist entry
            whitelist.extend(txt.read())

    # Training dataset
    print("Resolving input files")
    input_image_files = glob_all(args.files)
    gt_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in input_image_files]

    # two images sharing one ground truth file means a duplicated sample
    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = FileDataSet(input_image_files, gt_txt_files, skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    validation_dataset = None
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        val_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in validation_image_files]

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = FileDataSet(validation_image_files, val_txt_files,
                                         skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads

    # a negative early stopping frequency falls back to the checkpoint frequency
    params.early_stopping_frequency = args.early_stopping_frequency if args.early_stopping_frequency >= 0 else args.checkpoint_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    # the best model is written to the regular output directory unless overridden
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading) and text post processing (prediction) use
    # the same chain: normalizer, regularizer, strip.
    text_processors = (params.model.text_preprocessor, params.model.text_postprocessor)
    for text_processor in text_processors:
        text_processor.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(text_processor.children.add(), default=args.text_normalization)
        default_text_regularizer_params(text_processor.children.add(), groups=args.text_regularization)
        strip_processor_params = text_processor.children.add()
        strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR}
        bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        # the same direction is appended to both the pre- and the postprocessor
        for text_processor in text_processors:
            bidi_processor_params = text_processor.children.add()
            bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
            bidi_processor_params.bidi_direction = bidi_direction

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    # the enum value name is "CLIP_" + the upper-cased command line choice
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      )
    trainer.train(progress_bar=not args.no_progress_bars)
コード例 #14
0
ファイル: trainer.py プロジェクト: hajicj/ommr4all-server
    def _train(self, target_book: Optional[DatabaseBook] = None, callback: Optional[TrainerCallback] = None):
        """Train the text recognition model on the trainer's text-line datasets.

        A ``CheckpointParams`` message is filled from ``self.params`` and
        ``self.settings``. Afterwards either a ``CrossFoldTrainer`` is run
        (when ``calamari_params.n_folds`` is greater than 1) or a single
        ``Trainer`` is created and trained.

        Parameters
        ----------
        target_book : DatabaseBook, optional
            Not used by this method.
        callback : TrainerCallback, optional
            If given, ``resolving_files()`` is called on it first and it is
            handed to the dataset conversion. For the single-model run it is
            additionally wrapped in a ``CalamariTrainerCallback`` that is
            passed to ``Trainer.train``.
        """
        calamari_callback = None
        if callback:
            callback.resolving_files()
            calamari_callback = CalamariTrainerCallback(callback)

        # both datasets are converted with train=True
        training_lines = self.train_dataset.to_text_line_calamari_dataset(train=True, callback=callback)
        validation_lines = self.validation_dataset.to_text_line_calamari_dataset(train=True, callback=callback)
        model_dir = self.settings.model.path
        line_height = self.settings.dataset_params.height

        ckpt_params = CheckpointParams()

        # general training setup
        ckpt_params.max_iters = self.params.n_iter
        ckpt_params.stats_size = 1000
        ckpt_params.batch_size = 1
        ckpt_params.checkpoint_frequency = 0
        ckpt_params.output_dir = model_dir
        ckpt_params.output_model_prefix = 'text'
        ckpt_params.display = self.params.display
        ckpt_params.skip_invalid_gt = True
        ckpt_params.processes = 2
        ckpt_params.data_aug_retrain_on_original = True

        # early stopping; an unset target accuracy is stored as 0
        ckpt_params.early_stopping_at_acc = self.params.early_stopping_at_acc or 0
        ckpt_params.early_stopping_frequency = self.params.early_stopping_test_interval
        ckpt_params.early_stopping_nbest = self.params.early_stopping_max_keep
        ckpt_params.early_stopping_best_model_prefix = 'text_best'
        ckpt_params.early_stopping_best_model_output_dir = model_dir

        # model: default image normalisation, no text normalisation
        data_preprocessor = ckpt_params.model.data_preprocessor
        data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
        data_preprocessor.pad = 5
        data_preprocessor.line_height = line_height
        ckpt_params.model.text_preprocessor.type = TextProcessorParams.NOOP_NORMALIZER
        ckpt_params.model.text_postprocessor.type = TextProcessorParams.NOOP_NORMALIZER
        ckpt_params.model.line_height = line_height

        # a positive learning rate is appended to the network definition string
        network_definition = self.settings.calamari_params.network
        if self.params.l_rate > 0:
            network_definition = '{},learning_rate={}'.format(network_definition, self.params.l_rate)

        n_folds = self.settings.calamari_params.n_folds
        if n_folds > 1:
            # NOTE(review): the key "early-stopping_at_accuracy" contains a
            # hyphen while all sibling keys use underscores only - confirm
            # that the consumer of these arguments expects this spelling.
            fold_args = {
                "max_iters": ckpt_params.max_iters,
                "stats_size": ckpt_params.stats_size,
                "checkpoint_frequency": ckpt_params.checkpoint_frequency,
                "pad": 0,
                "network": network_definition,
                "early-stopping_at_accuracy": ckpt_params.early_stopping_at_acc,
                "early_stopping_frequency": ckpt_params.early_stopping_frequency,
                "early_stopping_nbest": ckpt_params.early_stopping_nbest,
                "line_height": ckpt_params.model.line_height,
                "data_preprocessing": ["RANGE_NORMALIZER", "FINAL_PREPARATION"],
            }
            fold_trainer = CrossFoldTrainer(
                n_folds, training_lines,
                model_dir, 'omr_best_{id}', fold_args, progress_bars=True
            )
            # Force to run in same scope as parent process
            fold_trainer.run(
                self.settings.calamari_params.single_folds,
                temporary_dir=os.path.join(model_dir, "temporary_dir"),
                spawn_subprocesses=False,
                max_parallel_models=1,
            )
        else:
            network_params_from_definition_string(network_definition, ckpt_params.model.network)

            n_augmentations = self.params.data_augmentation_factor or 0
            augmenter = SimpleDataAugmenter()
            # continue from the best checkpoint of a model to load, if there is one
            if self.params.model_to_load():
                initial_weights = self.params.model_to_load().local_file('text_best.ckpt')
            else:
                initial_weights = None

            single_trainer = Trainer(
                codec_whitelist='abcdefghijklmnopqrstuvwxyz ',      # Always keep space and all letters
                checkpoint_params=ckpt_params,
                dataset=training_lines,
                validation_dataset=validation_lines,
                n_augmentations=n_augmentations,
                data_augmenter=augmenter,
                weights=initial_weights,
                preload_training=True,
                preload_validation=True,
            )
            single_trainer.train(training_callback=calamari_callback,
                                 auto_compute_codec=True,
                                 )
コード例 #15
0
ファイル: predictor.py プロジェクト: zzmcdc/calamari
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 network=None,
                 batch_size=1,
                 processes=1):
        """ Set up a predictor from a stored checkpoint or from a loaded network

        Exactly one of `checkpoint` and `network` has to be given.

        Parameters
        ----------
        checkpoint : str, optional
            path of the checkpoint to load (without the '.json' suffix, which is appended to locate
            the parameter file). Mutually exclusive with `network`.
        text_postproc : TextProcessor, optional
            text processor applied to the predicted sentence.
            When a checkpoint is used and no processor is passed, it is created from the checkpoint parameters.
        data_preproc : DataProcessor, optional
            data processor applied to the input image.
            When a checkpoint is used and no processor is passed, it is created from the checkpoint parameters.
        codec : Codec, optional
            codec used for decoding. Mandatory together with `network`; with a `checkpoint` it defaults
            to a codec built from the charset stored in the checkpoint.
        network : ModelInterface, optional
            already created network. Mutually exclusive with `checkpoint`.
        batch_size : int, optional
            batch size passed on when the network is created from a checkpoint
        processes : int, optional
            number of processes; stored on the predictor and passed on when the backend is created

        Raises
        ------
        Exception
            if both or none of `checkpoint` and `network` are given, or if `network` is given without `codec`
        """
        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes

        # reject the two invalid argument combinations up front
        if not checkpoint and not network:
            raise Exception(
                "Either a checkpoint or a existing backend must be provided")
        if checkpoint and network:
            raise Exception(
                "Either a checkpoint or a network can be provided")

        if checkpoint:
            # the model description is stored as json next to the checkpoint
            with open(checkpoint + '.json', 'r') as f:
                stored_params = json_format.Parse(f.read(), CheckpointParams())
                self.model_params = stored_params.model

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(self.network_params,
                                                restore=self.checkpoint,
                                                processes=processes)
            self.network = backend.create_net(restore=self.checkpoint,
                                              weights=None,
                                              graph_type="predict",
                                              batch_size=batch_size)
            # explicitly passed processors take precedence over the stored ones
            self.text_postproc = text_postproc or text_processor_from_proto(
                self.model_params.text_postprocessor, "post")
            self.data_preproc = data_preproc or data_processor_from_proto(
                self.model_params.data_preprocessor)
        else:
            # preloaded network: no model parameters are available
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            if not codec:
                raise Exception(
                    "A codec is required if preloaded network is used.")

        self.codec = codec or Codec(self.model_params.codec.charset)
        self.out_to_in_trans = OutputToInputTransformer(
            self.data_preproc, self.network)
コード例 #16
0
ファイル: trainer.py プロジェクト: hajicj/ommr4all-server
    def _train(self,
               target_book: Optional[DatabaseBook] = None,
               callback: Optional[TrainerCallback] = None):
        """Train the model on the trainer's Calamari datasets.

        A ``CheckpointParams`` message is filled from ``self.params`` and
        ``self.settings``. Afterwards either a ``CrossFoldTrainer`` is run
        (when ``calamari_params.n_folds`` is greater than 0) or a single
        ``Trainer`` is created and trained.

        Parameters
        ----------
        target_book : DatabaseBook, optional
            Not used by this method.
        callback : TrainerCallback, optional
            If given, ``resolving_files()`` is called on it first and it is
            handed to the dataset conversion.
        """
        if callback:
            callback.resolving_files()

        # both datasets are converted with train=True
        training_data = self.train_dataset.to_calamari_dataset(
            train=True, callback=callback)
        validation_data = self.validation_dataset.to_calamari_dataset(
            train=True, callback=callback)

        ckpt_params = CheckpointParams()

        # general training setup
        ckpt_params.max_iters = self.params.n_iter
        ckpt_params.stats_size = 1000
        ckpt_params.batch_size = 5
        ckpt_params.checkpoint_frequency = 0
        ckpt_params.output_dir = self.settings.model.path
        ckpt_params.output_model_prefix = 'omr'
        ckpt_params.display = self.params.display
        ckpt_params.skip_invalid_gt = True
        ckpt_params.processes = self.params.processes
        ckpt_params.data_aug_retrain_on_original = False

        # early stopping
        ckpt_params.early_stopping_frequency = self.params.early_stopping_test_interval
        ckpt_params.early_stopping_nbest = self.params.early_stopping_max_keep
        ckpt_params.early_stopping_best_model_prefix = 'omr_best'
        ckpt_params.early_stopping_best_model_output_dir = self.settings.model.path

        # model: neither the data nor the text is normalised
        model_params = ckpt_params.model
        model_params.data_preprocessor.type = DataPreprocessorParams.NOOP_NORMALIZER
        # Disabled code kept for reference:
        # for preproc in [DataPreprocessorParams.RANGE_NORMALIZER, DataPreprocessorParams.FINAL_PREPARATION]:
        #    pp = params.model.data_preprocessor.children.add()
        #    pp.type = preproc
        model_params.text_preprocessor.type = TextProcessorParams.NOOP_NORMALIZER
        model_params.text_postprocessor.type = TextProcessorParams.NOOP_NORMALIZER
        model_params.line_height = self.settings.dataset_params.height
        model_params.network.channels = self.settings.calamari_params.channels

        # a positive learning rate is appended to the network definition string
        network_definition = self.settings.calamari_params.network
        if self.params.l_rate > 0:
            network_definition = '{},learning_rate={}'.format(
                network_definition, self.params.l_rate)

        # NOTE(review): cross-fold training already starts at n_folds == 1
        # here (threshold "> 0") - confirm that this is intended.
        n_folds = self.settings.calamari_params.n_folds
        if n_folds > 0:
            fold_args = {
                "max_iters": ckpt_params.max_iters,
                "stats_size": ckpt_params.stats_size,
                "checkpoint_frequency": ckpt_params.checkpoint_frequency,
                "pad": 0,
                "network": network_definition,
                "early_stopping_frequency": ckpt_params.early_stopping_frequency,
                "early_stopping_nbest": ckpt_params.early_stopping_nbest,
                "line_height": ckpt_params.model.line_height,
                "data_preprocessing": ["RANGE_NORMALIZER", "FINAL_PREPARATION"],
            }
            fold_trainer = CrossFoldTrainer(n_folds,
                                            training_data,
                                            ckpt_params.output_dir,
                                            'omr_best_{id}',
                                            fold_args,
                                            progress_bars=True)
            scratch_dir = os.path.join(ckpt_params.output_dir, "temporary_dir")
            # Force to run in same scope as parent process
            fold_trainer.run(
                self.settings.calamari_params.single_folds,
                temporary_dir=scratch_dir,
                spawn_subprocesses=False,
                max_parallel_models=1,
            )
        else:
            network_params_from_definition_string(network_definition,
                                                  ckpt_params.model.network)

            n_augmentations = self.settings.page_segmentation_params.data_augmentation * 10
            augmenter = SimpleDataAugmenter()
            # continue from the best checkpoint of a model to load, if there is one
            if self.params.model_to_load():
                initial_weights = self.params.model_to_load().local_file(
                    ckpt_params.early_stopping_best_model_prefix + '.ckpt')
            else:
                initial_weights = None
            # the codec is taken from the dataset settings
            fixed_codec = Codec(
                self.settings.dataset_params.calamari_codec.codec.values())

            single_trainer = Trainer(
                checkpoint_params=ckpt_params,
                dataset=training_data,
                validation_dataset=validation_data,
                n_augmentations=n_augmentations,
                data_augmenter=augmenter,
                weights=initial_weights,
                preload_training=True,
                preload_validation=True,
                codec=fixed_codec,
            )
            single_trainer.train()