def run_single(self):
    """Run one full experiment fold: (optionally) train, predict, and evaluate.

    Uses ``self.args`` for the fold's data splits and output locations and
    ``self.fold_log`` for progress logging.  Depending on the global flags
    (``skip_train`` / ``skip_predict`` / ``skip_eval``) individual stages are
    skipped; predictions are cached on disk (pickled) so a later run can
    re-evaluate without re-predicting.

    Returns:
        The evaluation result from ``self.evaluate`` or ``None`` when
        evaluation was skipped or there were no predictions.
    """
    args = self.args
    fold_log = self.fold_log
    # Imported locally, presumably to avoid circular imports at module load
    # time — TODO confirm.
    from omr.steps.algorithm import AlgorithmPredictorSettings, AlgorithmTrainerSettings, AlgorithmTrainerParams
    from omr.steps.step import Step
    global_args = args.global_args

    def print_dataset_content(files: List[PcGts], label: str):
        # Debug helper: log how many files a split has and where they live.
        fold_log.debug("Got {} {} files: {}".format(len(files), label, [f.page.location.local_path() for f in files]))

    print_dataset_content(args.train_pcgts_files, 'training')
    if args.validation_pcgts_files:
        print_dataset_content(args.validation_pcgts_files, 'validation')
    else:
        fold_log.debug("No validation data. Using training data instead")
    print_dataset_content(args.test_pcgts_files, 'testing')

    # Ensure the fold's model directory exists before writing anything to it.
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    # Cached predictions (pickled) and the best-model checkpoint path.
    # NOTE(review): 'pred.json' actually holds pickled data, not JSON.
    prediction_path = os.path.join(args.model_dir, 'pred.json')
    model_path = os.path.join(args.model_dir, 'best')

    if not global_args.skip_train:
        fold_log.info("Starting training")
        trainer = Step.create_trainer(
            global_args.algorithm_type,
            AlgorithmTrainerSettings(
                dataset_params=args.global_args.dataset_params,
                train_data=args.train_pcgts_files,
                # Fall back to the training set when no validation split exists.
                validation_data=args.validation_pcgts_files if args.validation_pcgts_files else args.train_pcgts_files,
                model=Model(MetaId.from_custom_path(model_path, global_args.algorithm_type)),
                params=global_args.trainer_params,
                page_segmentation_params=global_args.page_segmentation_params,
                calamari_params=global_args.calamari_params,
            )
        )
        trainer.train()

    test_pcgts_files = args.test_pcgts_files
    if not global_args.skip_predict:
        fold_log.info("Starting prediction")
        # Copy so the shared global predictor params are not mutated.
        pred_params = deepcopy(global_args.predictor_params)
        pred_params.modelId = MetaId.from_custom_path(model_path, global_args.algorithm_type)
        if global_args.calamari_dictionary_from_gt:
            # Build the CTC-decoder dictionary from the ground-truth text of
            # the test pages (hyphens stripped so split words are rejoined).
            words = set()
            for pcgts in test_pcgts_files:
                words = words.union(sum([t.sentence.text().replace('-', '').split() for t in pcgts.page.all_text_lines()], []))
            # Replace dictionary contents in place to keep the same list object.
            pred_params.ctcDecoder.params.dictionary[:] = words

        pred = Step.create_predictor(
            global_args.algorithm_type,
            AlgorithmPredictorSettings(
                None,
                pred_params,
            ))
        full_predictions = list(pred.predict([f.page.location for f in test_pcgts_files]))
        predictions = self.extract_gt_prediction(full_predictions)
        # Cache predictions so a later run with skip_predict can reuse them.
        with open(prediction_path, 'wb') as f:
            pickle.dump(predictions, f)

        if global_args.output_book:
            # Optionally write the predictions back into a (possibly new)
            # database book so they can be inspected in the UI.
            fold_log.info("Outputting data")
            pred_book = DatabaseBook(global_args.output_book)
            if not pred_book.exists():
                pred_book_meta = DatabaseBookMeta(pred_book.book, pred_book.book)
                pred_book.create(pred_book_meta)
            output_pcgts = [PcGts.from_file(pcgts.page.location.copy_to(pred_book).file('pcgts')) for pcgts in test_pcgts_files]
            self.output_prediction_to_book(pred_book, output_pcgts, full_predictions)
            for o_pcgts in output_pcgts:
                o_pcgts.to_file(o_pcgts.page.location.file('pcgts').local_path())
    else:
        # Prediction skipped: load the predictions cached by a previous run.
        fold_log.info("Skipping prediction")
        with open(prediction_path, 'rb') as f:
            predictions = pickle.load(f)

    predictions = tuple(predictions)

    if not global_args.skip_eval and len(predictions) > 0:
        fold_log.info("Starting evaluation")
        r = self.evaluate(predictions, global_args.evaluation_params)
    else:
        r = None

    # if not global_args.skip_cleanup:
    #     fold_log.info("Cleanup")
    #     shutil.rmtree(args.model_dir)

    return r
if __name__ == "__main__": from database import DatabaseBook from PIL import Image import matplotlib.pyplot as plt import numpy as np from omr.steps.step import Step, AlgorithmTypes b = DatabaseBook('demo') p = b.page('page00000001') img = np.array(Image.open(p.file('color_norm').local_path())) mask = np.zeros(img.shape, np.float) + 255 val_pcgts = [PcGts.from_file(p.file('pcgts'))] settings = PredictorSettings() pred = Step.create_predictor(AlgorithmTypes.LAYOUT_SIMPLE_BOUNDING_BOXES, settings) def s(c): return val_pcgts[0].page.page_to_image_scale(c, settings.page_scale_reference) for p in pred.predict(val_pcgts): for i, mr_c in enumerate(p.blocks.get(BlockType.MUSIC, [])): s(mr_c.coords).draw(mask, (255, 0, 0), fill=True, thickness=0) for i, mr_c in enumerate(p.blocks.get(BlockType.LYRICS, [])): s(mr_c.coords).draw(mask, (0, 255, 0), fill=True, thickness=0) for i, mr_c in enumerate(p.blocks.get(BlockType.DROP_CAPITAL, [])): s(mr_c.coords).draw(mask, (0, 0, 255), fill=True, thickness=0) import json