Exemplo n.º 1
0
    def run_single(self):
        """Train, predict and evaluate a single fold.

        The fold configuration is read from ``self.args`` and progress is
        logged to ``self.fold_log``. Controlled by the ``skip_train``,
        ``skip_predict`` and ``skip_eval`` flags of ``args.global_args``, the
        method

        1. trains a model that is stored under ``<model_dir>/best``,
        2. predicts the test files and pickles the extracted predictions to
           ``<model_dir>/pred.json``; when prediction is skipped the
           predictions are loaded from that file instead. If
           ``global_args.output_book`` is set, the predictions are also
           written into that book,
        3. evaluates the predictions.

        Returns:
            The value returned by ``self.evaluate``, or ``None`` if evaluation
            is skipped or there are no predictions.
        """
        args = self.args
        fold_log = self.fold_log
        # Imported locally as in the original code (presumably to avoid an
        # import cycle or a heavy import at module load -- TODO confirm).
        from omr.steps.algorithm import AlgorithmPredictorSettings, AlgorithmTrainerSettings, AlgorithmTrainerParams
        from omr.steps.step import Step
        global_args = args.global_args

        def print_dataset_content(files: List[PcGts], label: str):
            fold_log.debug("Got {} {} files: {}".format(len(files), label, [f.page.location.local_path() for f in files]))

        print_dataset_content(args.train_pcgts_files, 'training')
        if args.validation_pcgts_files:
            print_dataset_content(args.validation_pcgts_files, 'validation')
        else:
            fold_log.debug("No validation data. Using training data instead")
        print_dataset_content(args.test_pcgts_files, 'testing')

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(args.model_dir, exist_ok=True)

        # NOTE: despite its extension, 'pred.json' holds pickle data (see the
        # pickle.dump / pickle.load calls below).
        prediction_path = os.path.join(args.model_dir, 'pred.json')
        model_path = os.path.join(args.model_dir, 'best')

        if not global_args.skip_train:
            fold_log.info("Starting training")
            trainer = Step.create_trainer(
                global_args.algorithm_type,
                AlgorithmTrainerSettings(
                    dataset_params=global_args.dataset_params,
                    train_data=args.train_pcgts_files,
                    # Fall back to the training data if no validation data is given.
                    validation_data=args.validation_pcgts_files or args.train_pcgts_files,
                    model=Model(MetaId.from_custom_path(model_path, global_args.algorithm_type)),
                    params=global_args.trainer_params,
                    page_segmentation_params=global_args.page_segmentation_params,
                    calamari_params=global_args.calamari_params,
                )
            )
            trainer.train()

        test_pcgts_files = args.test_pcgts_files
        if not global_args.skip_predict:
            fold_log.info("Starting prediction")
            # Work on a copy so the shared predictor params are not modified.
            pred_params = deepcopy(global_args.predictor_params)
            pred_params.modelId = MetaId.from_custom_path(model_path, global_args.algorithm_type)
            if global_args.calamari_dictionary_from_gt:
                # Collect every whitespace separated word (hyphens removed) of
                # the ground truth text lines of the test files.
                words = set()
                for pcgts in test_pcgts_files:
                    for text_line in pcgts.page.all_text_lines():
                        words.update(text_line.sentence.text().replace('-', '').split())
                # Sorted so that the dictionary order does not depend on the
                # arbitrary iteration order of the set.
                pred_params.ctcDecoder.params.dictionary[:] = sorted(words)

            pred = Step.create_predictor(
                global_args.algorithm_type,
                AlgorithmPredictorSettings(
                    None,
                    pred_params,
                ))
            full_predictions = list(pred.predict([f.page.location for f in test_pcgts_files]))
            predictions = self.extract_gt_prediction(full_predictions)
            with open(prediction_path, 'wb') as f:
                pickle.dump(predictions, f)

            if global_args.output_book:
                fold_log.info("Outputting data")
                pred_book = DatabaseBook(global_args.output_book)
                if not pred_book.exists():
                    pred_book_meta = DatabaseBookMeta(pred_book.book, pred_book.book)
                    pred_book.create(pred_book_meta)

                # Copy each test page into the output book and load the copy.
                output_pcgts = [PcGts.from_file(pcgts.page.location.copy_to(pred_book).file('pcgts'))
                                for pcgts in test_pcgts_files]

                self.output_prediction_to_book(pred_book, output_pcgts, full_predictions)

                for o_pcgts in output_pcgts:
                    o_pcgts.to_file(o_pcgts.page.location.file('pcgts').local_path())

        else:
            fold_log.info("Skipping prediction")

            # NOTE(review): pickle.load executes arbitrary code from the file;
            # this is only safe because the file is expected to be the one
            # written by a previous run of this method -- never point
            # model_dir at untrusted data.
            with open(prediction_path, 'rb') as f:
                predictions = pickle.load(f)

        predictions = tuple(predictions)
        if not global_args.skip_eval and predictions:
            fold_log.info("Starting evaluation")
            r = self.evaluate(predictions, global_args.evaluation_params)
        else:
            r = None

        # NOTE: removing model_dir after the run is currently disabled.
        # if not global_args.skip_cleanup:
        #    fold_log.info("Cleanup")
        #    shutil.rmtree(args.model_dir)

        return r
Exemplo n.º 2
0
if __name__ == "__main__":
    # Demo: run the simple bounding box layout predictor on one page of the
    # 'demo' book and paint the predicted blocks into a mask image.
    from database import DatabaseBook
    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np
    from omr.steps.step import Step, AlgorithmTypes

    b = DatabaseBook('demo')
    p = b.page('page00000001')
    img = np.array(Image.open(p.file('color_norm').local_path()))
    # Mask with the shape of the image, every value initialised to 255.
    # `np.float` was removed in NumPy 1.24; the builtin `float` is the type it
    # aliased and yields the same float64 array.
    mask = np.zeros(img.shape, float) + 255
    val_pcgts = [PcGts.from_file(p.file('pcgts'))]

    settings = PredictorSettings()
    pred = Step.create_predictor(AlgorithmTypes.LAYOUT_SIMPLE_BOUNDING_BOXES, settings)

    def s(c):
        """Scale page coordinates *c* to image coordinates of the first page."""
        return val_pcgts[0].page.page_to_image_scale(c, settings.page_scale_reference)

    # One fill colour per block type, drawn in this order.
    block_colors = (
        (BlockType.MUSIC, (255, 0, 0)),
        (BlockType.LYRICS, (0, 255, 0)),
        (BlockType.DROP_CAPITAL, (0, 0, 255)),
    )

    # NOTE: the loop variable `p` rebinds the page `p` defined above.
    for p in pred.predict(val_pcgts):
        for block_type, color in block_colors:
            for mr_c in p.blocks.get(block_type, []):
                s(mr_c.coords).draw(mask, color, fill=True, thickness=0)

    import json