コード例 #1
0
ファイル: algorithm.py プロジェクト: hajicj/ommr4all-server
    def default_model_for_style(cls, style: str) -> Optional[Model]:
        """Return the default model for ``style``, falling back to 'french14'.

        The model is looked up via ``ModelsId.from_internal(style, cls.type())``
        combined with ``cls.model_dir()``. If that model does not exist, the
        lookup is repeated once for the 'french14' style.

        :param style: name of the notation style to look up.
        :return: the existing model for ``style`` or for the fallback style;
            ``None`` if even the fallback model does not exist.
        """
        fallback_style = 'french14'

        models = ModelsId.from_internal(style, cls.type())
        model = Model(MetaId(models, cls.model_dir()))
        if model.exists():
            return model

        if style == fallback_style:
            # The fallback itself is missing. Recursing again would repeat
            # the very same lookup forever and end in a RecursionError.
            return None

        # fallback: french14 is expected to exist
        return cls.default_model_for_style(fallback_style)
コード例 #2
0
ファイル: algorithm.py プロジェクト: hajicj/ommr4all-server
    def create_new_model(cls,
                         book: DatabaseBook,
                         id: Optional[str] = None) -> Model:
        """Build a new ``Model`` descriptor for ``book``, stamped with the current time.

        :param book: book the model belongs to; its ``book`` attribute and the
            ``notationStyle`` of its meta data are used.
        :param id: identifier stored in the model meta; a random UUID4 string
            is generated when it is falsy.
        :return: a ``Model`` whose ``MetaId`` name is the current local time
            formatted as ``%Y-%m-%dT%H:%M:%S``.
        """
        import datetime

        model_uuid = id or str(uuid.uuid4())
        created_at = datetime.datetime.now()

        meta_id = MetaId(
            ModelsId.from_external(book.book, cls.type()),
            created_at.strftime("%Y-%m-%dT%H:%M:%S"),
        )
        model_meta = ModelMeta(
            model_uuid,
            created_at,
            style=book.get_meta().notationStyle,
        )
        return Model(meta_id, model_meta)
コード例 #3
0
 def put(self, request, group, style):
     """Copy the model named in the request body over the local default model.

     The request body is parsed as ``ModelMeta`` JSON; the model with that
     id is copied (overriding existing content) to the local default model
     location of ``style`` for the first algorithm type of ``group``.

     :return: an empty ``Response``.
     """
     source_meta = ModelMeta.from_json(request.body)
     # Only the first algorithm type of the group is considered.
     algorithm_type = AlgorithmGroups(group).types()[0]
     source_model = Model.from_id_str(source_meta.id)

     destination_models = DatabaseAvailableModels.local_default_models(
         style, algorithm_type)
     destination_dir = Step.meta(algorithm_type).model_dir()
     destination_model = Model(MetaId(destination_models, destination_dir))

     source_model.copy_to(destination_model, override=True)
     return Response()
コード例 #4
0
    def _test_list_models(self, operation: AlgorithmTypes):
        """Check the model listing endpoint of the 'demo' book for ``operation``.

        Requests ``/api/book/<book>/operation/<operation>/models`` and asserts
        that there is no newest model, that both the selected and the default
        book style model are the internal 'french14' model, and that no
        book-specific or same-style models are listed.
        """
        from database.database_available_models import DatabaseAvailableModels
        from database.model import ModelsId, MetaId

        demo_book = DatabaseBook('demo')
        url = '/api/book/{}/operation/{}/models'.format(
            demo_book.book, operation.value)
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK, response)

        # NOTE(review): the expected id is always built for STAFF_LINES_PC,
        # independent of ``operation`` - confirm this is intended.
        models_id = ModelsId.from_internal('french14',
                                           AlgorithmTypes.STAFF_LINES_PC)
        available = DatabaseAvailableModels.from_dict(response.data)
        self.assertEqual(available.newest_model, None)

        expected_id = str(MetaId(models_id, models_id.algorithm_type.value))
        self.assertEqual(available.selected_model.id, expected_id)
        self.assertEqual(available.default_book_style_model.id, expected_id)

        self.assertListEqual(available.book_models, [])
        self.assertListEqual(available.models_of_same_book_style, [])
コード例 #5
0
ファイル: predictor.py プロジェクト: hajicj/ommr4all-server
    def __init__(self, settings: AlgorithmPredictorSettings):
        """Initialise the predictor and attach a Calamari OCR predictor.

        The OCR predictor always loads the model stored below
        ``BASE_DIR`` at ``/internal_storage/default_models/fraktur/text_calamari/``
        and uses the default CTC decoder.

        :param settings: predictor settings forwarded to the base class.
        """
        super().__init__(settings)

        from ommr4all.settings import BASE_DIR

        ocr_meta = Step.meta(AlgorithmTypes.OCR_CALAMARI)
        ocr_model_path = (
            BASE_DIR
            + '/internal_storage/default_models/fraktur/text_calamari/'
        )
        ocr_model = Model(MetaId.from_custom_path(ocr_model_path, ocr_meta.type()))

        # The OCR predictor gets its own settings; the ones passed in are
        # only handed to the base class above.
        ocr_settings = AlgorithmPredictorSettings(model=ocr_model)
        ocr_settings.params.ctcDecoder.params.type = CTCDecoderParams.CTC_DEFAULT
        self.ocr_predictor = ocr_meta.create_predictor(ocr_settings)
コード例 #6
0
    def __init__(self, settings: AlgorithmPredictorSettings):
        """Initialise the predictor with document data and a Calamari OCR predictor.

        Stores the document id and text from ``settings.params``, creates the
        similar-document checker and a word-level lyrics normalizer, and
        builds an OCR predictor that loads the model stored below ``BASE_DIR``
        at ``/internal_storage/default_models/fraktur/text_calamari/`` with
        the default CTC decoder.

        :param settings: predictor settings; ``params.documentId`` and
            ``params.documentText`` are read from it.
        """
        super().__init__(settings)

        params = settings.params
        self.document_id = params.documentId
        self.document_text = params.documentText

        self.document_similar_tester = SimilarDocumentChecker()
        normalization_params = LyricsNormalizationParams(LyricsNormalization.WORDS)
        self.text_normalizer = LyricsNormalizationProcessor(normalization_params)

        from ommr4all.settings import BASE_DIR

        ocr_meta = Step.meta(AlgorithmTypes.OCR_CALAMARI)
        ocr_model_path = (
            BASE_DIR
            + '/internal_storage/default_models/fraktur/text_calamari/'
        )
        ocr_model = Model(MetaId.from_custom_path(ocr_model_path, ocr_meta.type()))

        # The OCR predictor gets its own settings, independent of the ones
        # passed to this constructor.
        ocr_settings = AlgorithmPredictorSettings(model=ocr_model)
        ocr_settings.params.ctcDecoder.params.type = CTCDecoderParams.CTC_DEFAULT
        self.ocr_predictor = ocr_meta.create_predictor(ocr_settings)
コード例 #7
0
    def run_single(self):
        """Run training, prediction and evaluation for one fold.

        The stages are controlled by flags on ``self.args.global_args``:

        * unless ``skip_train`` is set, a trainer is created and trained; the
          model is written to ``<model_dir>/best``. Training data doubles as
          validation data when no validation files are given.
        * unless ``skip_predict`` is set, the test files are predicted, the
          extracted predictions are pickled to ``<model_dir>/pred.json`` and,
          if ``output_book`` is set, written into that book. Otherwise the
          previously pickled predictions are loaded from that file.
        * unless ``skip_eval`` is set or there are no predictions, the
          predictions are evaluated.

        :return: the result of ``self.evaluate`` or ``None`` when the
            evaluation was skipped.
        """
        args = self.args
        fold_log = self.fold_log
        from omr.steps.algorithm import AlgorithmPredictorSettings, AlgorithmTrainerSettings, AlgorithmTrainerParams
        from omr.steps.step import Step
        global_args = args.global_args

        def print_dataset_content(files: List[PcGts], label: str):
            fold_log.debug("Got {} {} files: {}".format(len(files), label, [f.page.location.local_path() for f in files]))

        print_dataset_content(args.train_pcgts_files, 'training')
        if args.validation_pcgts_files:
            print_dataset_content(args.validation_pcgts_files, 'validation')
        else:
            fold_log.debug("No validation data. Using training data instead")
        print_dataset_content(args.test_pcgts_files, 'testing')

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test (e.g. when several folds start at once).
        os.makedirs(args.model_dir, exist_ok=True)

        # NOTE: despite its '.json' name this file holds pickled data.
        prediction_path = os.path.join(args.model_dir, 'pred.json')
        model_path = os.path.join(args.model_dir, 'best')

        if not global_args.skip_train:
            fold_log.info("Starting training")
            trainer = Step.create_trainer(
                global_args.algorithm_type,
                AlgorithmTrainerSettings(
                    dataset_params=global_args.dataset_params,
                    train_data=args.train_pcgts_files,
                    # Fall back to the training data if no validation data is available.
                    validation_data=args.validation_pcgts_files if args.validation_pcgts_files else args.train_pcgts_files,
                    model=Model(MetaId.from_custom_path(model_path, global_args.algorithm_type)),
                    params=global_args.trainer_params,
                    page_segmentation_params=global_args.page_segmentation_params,
                    calamari_params=global_args.calamari_params,
                )
            )
            trainer.train()

        test_pcgts_files = args.test_pcgts_files
        if not global_args.skip_predict:
            fold_log.info("Starting prediction")
            # Work on a copy so the shared predictor params stay untouched.
            pred_params = deepcopy(global_args.predictor_params)
            pred_params.modelId = MetaId.from_custom_path(model_path, global_args.algorithm_type)
            if global_args.calamari_dictionary_from_gt:
                # Collect the distinct words (hyphens removed) of all test
                # text lines as decoder dictionary.
                words = set()
                for pcgts in test_pcgts_files:
                    for text_line in pcgts.page.all_text_lines():
                        words.update(text_line.sentence.text().replace('-', '').split())
                pred_params.ctcDecoder.params.dictionary[:] = words

            pred = Step.create_predictor(
                global_args.algorithm_type,
                AlgorithmPredictorSettings(
                    None,
                    pred_params,
                ))
            full_predictions = list(pred.predict([f.page.location for f in test_pcgts_files]))
            predictions = self.extract_gt_prediction(full_predictions)
            with open(prediction_path, 'wb') as f:
                pickle.dump(predictions, f)

            if global_args.output_book:
                fold_log.info("Outputting data")
                pred_book = DatabaseBook(global_args.output_book)
                if not pred_book.exists():
                    pred_book_meta = DatabaseBookMeta(pred_book.book, pred_book.book)
                    pred_book.create(pred_book_meta)

                # Copy every test page into the output book and load the copies.
                output_pcgts = [PcGts.from_file(pcgts.page.location.copy_to(pred_book).file('pcgts'))
                                for pcgts in test_pcgts_files]

                self.output_prediction_to_book(pred_book, output_pcgts, full_predictions)

                for o_pcgts in output_pcgts:
                    o_pcgts.to_file(o_pcgts.page.location.file('pcgts').local_path())

        else:
            fold_log.info("Skipping prediction")

            # SECURITY: pickle.load executes arbitrary code from the file;
            # only load prediction files written by this tool itself.
            with open(prediction_path, 'rb') as f:
                predictions = pickle.load(f)

        predictions = tuple(predictions)
        if not global_args.skip_eval and len(predictions) > 0:
            fold_log.info("Starting evaluation")
            r = self.evaluate(predictions, global_args.evaluation_params)
        else:
            r = None

        # if not global_args.skip_cleanup:
        #    fold_log.info("Cleanup")
        #    shutil.rmtree(args.model_dir)

        return r
コード例 #8
0
ファイル: algorithm.py プロジェクト: hajicj/ommr4all-server
 def default_model_for_book(cls, book: DatabaseBook) -> Model:
     """Return the internal default model matching the notation style of ``book``.

     The model is only constructed here; no check is made that it exists.

     :param book: book whose meta data provides the ``notationStyle``.
     :return: the ``Model`` addressed by the internal models id of that
         style and ``cls.model_dir()``.
     """
     notation_style = book.get_meta().notationStyle
     models_id = ModelsId.from_internal(notation_style, cls.type())
     meta_id = MetaId(models_id, cls.model_dir())
     return Model(meta_id)
コード例 #9
0
     minNumberOfStaffLines=args.min_number_of_staff_lines,
     maxNumberOfStaffLines=args.max_number_of_staff_lines,
     ctcDecoder=SerializableCTCDecoderParams(
         type=CTCDecoderParams.CTCDecoderType.Value(args.calamari_ctc_decoder),
         beam_width=args.calamari_ctc_decoder_beam_width,
         word_separator='' if args.calamari_ctc_word_separator.upper() == "BLANK" else ' ' if args.calamari_ctc_word_separator.upper() == "SPACE" else args.calamari_ctc_word_separator,
         dictionary=None if not args.calamari_ctc_dictionary else open(args.calamari_ctc_dictionary, 'r').read().split(),
         non_word_chars=args.calamari_ctc_decoder_non_word_chars,
     )
 ),
 output_book=args.output_book,
 algorithm_type=args.type,
 trainer_params=AlgorithmTrainerParams(
     n_iter=args.n_iter,
     display=100,
     load=str(MetaId.from_custom_path(args.pretrained_model, args.type)) if args.pretrained_model else None,
     processes=8,
     early_stopping_at_acc=args.early_stopping_at_accuracy,
     early_stopping_max_keep=args.early_stopping_max_keep,
     early_stopping_test_interval=args.early_stopping_test_interval,
     train_data_multiplier=args.train_data_multiplier,
     data_augmentation_factor=args.data_augmentation_factor,
 ),
 page_segmentation_params=PageSegmentationTrainerParams(
     data_augmentation=args.data_augmentation,
     architecture=args.page_segmentation_architecture,
 ),
 calamari_params=CalamariParams(
     network=args.calamari_network,
     n_folds=args.calamari_n_folds,
     single_folds=args.calamari_single_folds,