def __init__(self, glyph_classifier_fn=None):
    """Builds a private TF graph and session that run OMR on one PNG page.

    The graph reads a PNG from the `png_path` placeholder, decodes it into a
    music score image, builds the `Structure` for that image, and then loads
    the glyph classifier on top of that structure.

    Args:
      glyph_classifier_fn: Callable taking the page `Structure` as its only
        argument and returning a `BaseGlyphClassifier`. It usually loads a TF
        saved model (or other external data) and wraps classification in a
        concrete glyph classifier subclass. A classifier that relies on a
        `StafflineExtractor` must set the `staffline_extractor` attribute of
        the `Structure`; otherwise glyph x coordinates are not scaled back to
        image coordinates. When falsy, the callable produced by
        `saved_classifier_fn.build_classifier_fn()` is used.
    """
    if not glyph_classifier_fn:
        glyph_classifier_fn = saved_classifier_fn.build_classifier_fn()

    self.graph = tf.Graph()
    self.session = tf.Session(graph=self.graph)
    with self.graph.as_default(), self.session.as_default():
        with tf.name_scope('OMREngine'):
            self.png_path = tf.placeholder(tf.string, name='png_path', shape=())
            page_bytes = tf.read_file(self.png_path, name='page_image')
            self.image = image.decode_music_score_png(page_bytes)
            self.structure = structure_module.create_structure(self.image)
        # The classifier is loaded outside the name scope on purpose: a scope
        # can rename tensors coming from the saved model and leave dangling
        # references.
        # TODO(ringw): TF should be able to load models gracefully within a
        # name scope.
        self.glyph_classifier = glyph_classifier_fn(self.structure)
def testCorpusImage(self):
    """Checks staffline distance/thickness estimation on the corpus image.

    The expected values (distance 16, thickness 2) were determined manually
    for testdata/IMSLP00747-000.png.
    """
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    # Read the bytes inside a context manager so the file handle is closed
    # deterministically instead of whenever the file object is collected.
    with open(filename, 'rb') as png_file:
        image_contents = png_file.read()
    image_t = decode_music_score_png(tf.constant(image_contents))
    staffdist_t, staffthick_t = (
        staffline_distance.estimate_staffline_distance_and_thickness(image_t))
    with self.test_session() as sess:
        staffdist, staffthick = sess.run((staffdist_t, staffthick_t))
    # Manually determined values for the image.
    self.assertAllEqual(staffdist, [16])
    # assertEquals is a deprecated alias; use assertEqual.
    self.assertEqual(staffthick, 2)
def testCompute(self):
    """Computes the page structure of the corpus image and checks its arrays.

    Checks the dtype and shape of the detected staves and the dtype, rank and
    trailing dimensions of the detected vertical lines.
    """
    png_path = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    page = image_module.decode_music_score_png(tf.read_file(png_path))
    pending = structure.create_structure(page)
    with self.test_session():
        computed = pending.compute()

    staves_arr = computed.staff_detector.staves
    self.assertEqual(np.int32, staves_arr.dtype)
    # Expected number of staves for the corpus image.
    self.assertEqual((12, 2, 2), staves_arr.shape)

    vertical_lines = computed.verticals.lines
    self.assertEqual(np.int32, vertical_lines.dtype)
    self.assertEqual(3, vertical_lines.ndim)
    self.assertEqual((2, 2), vertical_lines.shape[1:])
def test_corpus_image(self):
    """Checks that staff removal erases the stafflines of the corpus image.

    The removed image must differ from the original, and staffline distance
    estimation on it must find nothing (empty distances, thickness -1).
    """
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    image_t = omr_image.decode_music_score_png(tf.read_file(filename))
    remover = removal.StaffRemover(staves.StaffDetector(image_t))
    with self.test_session() as sess:
        removed, image = sess.run([remover.remove_staves, image_t])
        self.assertFalse(np.allclose(removed, image))
        # If staff removal runs successfully, we should be unable to estimate
        # the staffline distance from the staves-removed image.
        est_staffline_distance, est_staffline_thickness = sess.run(
            staffline_distance.estimate_staffline_distance_and_thickness(removed))
        # (A leftover debug print of the estimated distance was removed here;
        # the assertions below already report the value on failure.)
        self.assertAllEqual([], est_staffline_distance)
        self.assertEqual(-1, est_staffline_thickness)
def testSaveAndLoadDummyClassifier(self):
    """Exports a trivial SavedModel and reloads it as a glyph classifier.

    The exported model maps every (18, 15) float patch to class id 1. After
    reloading it through `saved_classifier.SavedConvolutional1DClassifier` on
    the corpus image, the test checks:
      - the default run min length is kept;
      - the predictions are non-empty and have rank 3;
      - every prediction equals `musicscore_pb2.Glyph.NONE`.
    """
    with tempfile.TemporaryDirectory() as base_dir:
        export_dir = os.path.join(base_dir, 'export')

        # Phase 1: build and export the constant-output model.
        with self.test_session() as sess:
            patches = tf.placeholder(tf.float32, shape=(None, 18, 15))
            patch_count = tf.shape(patches)[0]
            # Glyph.NONE is number 1.
            constant_ids = tf.ones([patch_count], tf.int32)
            signature_inputs = {
                'input': tf.saved_model.utils.build_tensor_info(patches)
            }
            signature_outputs = {
                'class_ids': tf.saved_model.utils.build_tensor_info(constant_ids)
            }
            serving_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    signature_inputs, signature_outputs, 'serve'))
            serving_key = (tf.saved_model.signature_constants
                           .DEFAULT_SERVING_SIGNATURE_DEF_KEY)
            model_builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
            model_builder.add_meta_graph_and_variables(
                sess, ['serve'],
                signature_def_map={serving_key: serving_signature})
            model_builder.save()

        tf.reset_default_graph()

        # Phase 2: load the saved model.
        with self.test_session() as sess:
            png_path = os.path.join(tf.resource_loader.get_data_files_path(),
                                    '../testdata/IMSLP00747-000.png')
            page = image.decode_music_score_png(tf.read_file(png_path))
            classifier = saved_classifier.SavedConvolutional1DClassifier(
                structure.create_structure(page), export_dir)

            # Run min length should be the default.
            self.assertEqual(classifier.run_min_length,
                             convolutional.DEFAULT_RUN_MIN_LENGTH)

            predicted = classifier.staffline_predictions.eval()
            self.assertEqual(predicted.ndim, 3)  # Staff, staff position, x
            self.assertGreater(predicted.size, 0)
            # Predictions are all musicscore_pb2.Glyph.NONE.
            all_none = np.full(predicted.shape, musicscore_pb2.Glyph.NONE,
                               np.int32)
            self.assertAllEqual(predicted, all_none)
def test_corpus_image(self):
    """Checks the default staff detector's output on the corpus image.

    The mean y position of each detected staff must be within 5 px of its
    reference value, and every staff must report a staffline distance of 16.
    """
    # Test only the default staff detector (because projection won't detect all
    # staves).
    expected_mean_ys = [
        413, 603, 848, 1040, 1286, 1476, 1724, 1915, 2162, 2354, 2604, 2795
    ]
    png_path = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    page_t = omr_image.decode_music_score_png(tf.read_file(png_path))
    detector = staves.StaffDetector(page_t)
    with self.test_session() as sess:
        staves_arr, staffline_distances = sess.run(
            [detector.staves, detector.staffline_distance])
        # Average of column 1 (the y position) over each staff's points.
        mean_ys = staves_arr[:, :, 1].mean(axis=1)
        self.assertAllClose(mean_ys, expected_mean_ys, atol=5)
        self.assertAllEqual(staffline_distances, [16] * len(expected_mean_ys))
def __init__(self,
             num_sections=DEFAULT_NUM_SECTIONS,
             patch_height=15,
             patch_width=12,
             run_options=None):
    """Builds a private TF graph that selects one extracted staffline strip.

    The graph is fed three placeholders: `filename`, `staff_index` and
    `y_position`. It exposes:
      - `all_stafflines`: every extracted strip;
      - `staffline`: the strip selected by the placeholders;
      - `all_staffline_scales` / `staffline_scale`: per-staff factors that
        convert image x coordinates into the scaled strip.

    Args:
      num_sections: Passed to `StafflineExtractor` as `num_sections`. It also
        determines which strip index corresponds to y position 0.
      patch_height: Passed to `StafflineExtractor` as `target_height`.
      patch_width: Stored on the instance; not used within this constructor.
      run_options: Stored on the instance; not used within this constructor.
    """
    self.num_sections = num_sections
    self.patch_height = patch_height
    self.patch_width = patch_width
    self.run_options = run_options

    self.graph = tf.Graph()
    with self.graph.as_default():
        # Placeholders identifying which patch to extract.
        self.filename = tf.placeholder(tf.string, name='filename')
        self.staff_index = tf.placeholder(tf.int64, name='staff_index')
        self.y_position = tf.placeholder(tf.int64, name='y_position')

        page = image_module.decode_music_score_png(tf.read_file(self.filename))
        detector = staves_module.StaffDetector(page)
        remover = removal.StaffRemover(detector)
        line_extractor = StafflineExtractor(
            remover.remove_staves,
            detector,
            num_sections=num_sections,
            target_height=patch_height)

        # y position 0 maps to the middle strip. Larger y positions (higher
        # notes, nearer the top of the image) map to smaller strip indices.
        strip_index = num_sections // 2 - self.y_position
        self.all_stafflines = line_extractor.extract_staves()
        # The entire extracted horizontal strip of the image.
        self.staffline = self.all_stafflines[self.staff_index, strip_index]

        # Scale factor from image x coordinates to the scaled staff strip the
        # patch is cut from: extracted strip height divided by the unscaled
        # strip height (a multiple of each staff's staffline distance).
        scaled_strip_height = tf.shape(self.all_stafflines)[2]
        unscaled_strip_heights = tf.multiply(
            DEFAULT_STAFFLINE_DISTANCE_MULTIPLE,
            detector.staffline_distance)
        self.all_staffline_scales = tf.divide(
            tf.to_float(scaled_strip_height),
            tf.to_float(unscaled_strip_heights))
        self.staffline_scale = self.all_staffline_scales[self.staff_index]
def pipeline_graph(png_path, staffline_height, patch_width, num_stafflines):
    """Wires up the PNG-to-staffline-patches pipeline graph.

    The pipeline decodes the page, detects and removes staves, and extracts
    the strips around each staff. The strips are exposed as a tensor named
    'stafflines' and then split into patches.

    Args:
      png_path: String scalar tensor holding the path of the input PNG.
      staffline_height: int. Target height of each extracted staffline (passed
        to `StafflineExtractor` as `target_height`).
      patch_width: int. Width of each patch.
      num_stafflines: int. Number of stafflines to extract around each staff
        (passed to `StafflineExtractor` as `num_sections`).

    Returns:
      A tensor representing the staffline patches. float32 with shape
      (num_patches, staffline_height, patch_width).
    """
    page_t = image.decode_music_score_png(tf.read_file(png_path))
    detector = staves.StaffDetector(page_t)
    remover = removal.StaffRemover(detector)
    line_extractor = staffline_extractor.StafflineExtractor(
        remover.remove_staves,
        detector,
        target_height=staffline_height,
        num_sections=num_stafflines)
    named_stafflines = tf.identity(
        line_extractor.extract_staves(), name='stafflines')
    return _extract_patches(named_stafflines, patch_width)