示例#1
0
    def __init__(self, glyph_classifier_fn=None):
        """Builds the TF graph and session used for running OMR.

        Args:
          glyph_classifier_fn: Callable that loads the glyph classifier into
            the graph. It receives a `Structure` as its only argument and
            returns an instance of `BaseGlyphClassifier`. Typically it loads a
            TF saved model or other external data and wraps the classification
            in a concrete glyph classifier subclass. A classifier that relies
            on a `StafflineExtractor` must set the `staffline_extractor`
            attribute of the `Structure`; otherwise glyph x coordinates will
            not be scaled back to image coordinates. When omitted (or falsy),
            `saved_classifier_fn.build_classifier_fn()` supplies the callable.
        """
        if not glyph_classifier_fn:
            glyph_classifier_fn = saved_classifier_fn.build_classifier_fn()

        omr_graph = tf.Graph()
        self.graph = omr_graph
        self.session = tf.Session(graph=omr_graph)

        with self.graph.as_default(), self.session.as_default():
            with tf.name_scope('OMREngine'):
                # Scalar string placeholder fed with the path of the page PNG.
                self.png_path = tf.placeholder(
                    tf.string, name='png_path', shape=())
                page_contents = tf.read_file(self.png_path, name='page_image')
                self.image = image.decode_music_score_png(page_contents)
                self.structure = structure_module.create_structure(self.image)
            # The classifier is loaded outside of the name scope on purpose:
            # scopes can rename tensors from a saved model and leave dangling
            # references.
            # TODO(ringw): TF should be able to load models gracefully within a
            # name scope.
            self.glyph_classifier = glyph_classifier_fn(self.structure)
 def testCorpusImage(self):
     """Checks the estimated staffline distance and thickness of a corpus page.

     The PNG bytes are read in Python, wrapped in a constant tensor, decoded,
     and passed to the staffline distance/thickness estimator.
     """
     filename = os.path.join(tf.resource_loader.get_data_files_path(),
                             '../testdata/IMSLP00747-000.png')
     # Read through a context manager so the file handle is closed
     # deterministically instead of being left to the garbage collector.
     with open(filename, 'rb') as image_file:
         image_contents = image_file.read()
     image_t = decode_music_score_png(tf.constant(image_contents))
     staffdist_t, staffthick_t = (
         staffline_distance.estimate_staffline_distance_and_thickness(image_t))
     with self.test_session() as sess:
         staffdist, staffthick = sess.run((staffdist_t, staffthick_t))
     # Manually determined values for the image.
     self.assertAllEqual(staffdist, [16])
     # assertEqual replaces the deprecated assertEquals alias, which no longer
     # exists in Python 3.12+.
     self.assertEqual(staffthick, 2)
示例#3
0
    def testCompute(self):
        """Computes the structure of a corpus page and checks dtypes and shapes."""
        png_path = os.path.join(tf.resource_loader.get_data_files_path(),
                                '../testdata/IMSLP00747-000.png')
        page = image_module.decode_music_score_png(tf.read_file(png_path))
        computed = structure.create_structure(page)
        with self.test_session():
            # compute() is run while the test session is the default session.
            computed = computed.compute()

        detected_staves = computed.staff_detector.staves
        self.assertEqual(np.int32, detected_staves.dtype)
        # Expected number of staves for the corpus image.
        self.assertEqual((12, 2, 2), detected_staves.shape)

        vertical_lines = computed.verticals.lines
        self.assertEqual(np.int32, vertical_lines.dtype)
        self.assertEqual(3, vertical_lines.ndim)
        self.assertEqual((2, 2), vertical_lines.shape[1:])
示例#4
0
 def test_corpus_image(self):
     """Removes staves from a corpus page and checks that none can be found after."""
     png_path = os.path.join(tf.resource_loader.get_data_files_path(),
                             '../testdata/IMSLP00747-000.png')
     page_t = omr_image.decode_music_score_png(tf.read_file(png_path))
     detector = staves.StaffDetector(page_t)
     staff_remover = removal.StaffRemover(detector)
     with self.test_session() as sess:
         staves_removed, original = sess.run(
             [staff_remover.remove_staves, page_t])
         # Staff removal must have changed the image.
         self.assertFalse(np.allclose(staves_removed, original))
         # If staff removal runs successfully, we should be unable to estimate
         # the staffline distance from the staves-removed image.
         estimate_t = (
             staffline_distance.estimate_staffline_distance_and_thickness(
                 staves_removed))
         distance, thickness = sess.run(estimate_t)
         print(distance)
         self.assertAllEqual([], distance)
         self.assertEqual(-1, thickness)
示例#5
0
 def testSaveAndLoadDummyClassifier(self):
     """Exports a constant-output saved model, then reloads it as a classifier.

     The exported graph returns class id 1 for every input patch, so every
     prediction made by the reloaded classifier is expected to be Glyph.NONE.
     """
     with tempfile.TemporaryDirectory() as tmp_root:
         saved_model_dir = os.path.join(tmp_root, 'export')

         # Build and export the dummy model.
         with self.test_session() as sess:
             patches_t = tf.placeholder(tf.float32, shape=(None, 18, 15))
             patch_count = tf.shape(patches_t)[0]
             # Glyph.NONE is number 1.
             class_ids_t = tf.ones([patch_count], tf.int32)
             signature_inputs = {
                 'input': tf.saved_model.utils.build_tensor_info(patches_t),
             }
             signature_outputs = {
                 'class_ids':
                     tf.saved_model.utils.build_tensor_info(class_ids_t),
             }
             signature = (
                 tf.saved_model.signature_def_utils.build_signature_def(
                     signature_inputs, signature_outputs, 'serve'))
             serving_key = (tf.saved_model.signature_constants
                            .DEFAULT_SERVING_SIGNATURE_DEF_KEY)
             builder = tf.saved_model.builder.SavedModelBuilder(
                 saved_model_dir)
             builder.add_meta_graph_and_variables(
                 sess, ['serve'], signature_def_map={serving_key: signature})
             builder.save()

         tf.reset_default_graph()

         # Load the saved model.
         with self.test_session():
             png_path = os.path.join(
                 tf.resource_loader.get_data_files_path(),
                 '../testdata/IMSLP00747-000.png')
             page = image.decode_music_score_png(tf.read_file(png_path))
             page_structure = structure.create_structure(page)
             classifier = saved_classifier.SavedConvolutional1DClassifier(
                 page_structure, saved_model_dir)
             # Run min length should be the default.
             self.assertEqual(classifier.run_min_length,
                              convolutional.DEFAULT_RUN_MIN_LENGTH)
             predictions = classifier.staffline_predictions.eval()
             # Axes: staff, staff position, x.
             self.assertEqual(predictions.ndim, 3)
             self.assertGreater(predictions.size, 0)
             # Predictions are all musicscore_pb2.Glyph.NONE.
             all_none = np.full(predictions.shape, musicscore_pb2.Glyph.NONE,
                                np.int32)
             self.assertAllEqual(predictions, all_none)
示例#6
0
 def test_corpus_image(self):
     """Checks detected staff positions and staffline distances on a corpus page."""
     # Test only the default staff detector (because projection won't detect
     # all staves).
     png_path = os.path.join(tf.resource_loader.get_data_files_path(),
                             '../testdata/IMSLP00747-000.png')
     page_t = omr_image.decode_music_score_png(tf.read_file(png_path))
     staff_detector = staves.StaffDetector(page_t)
     with self.test_session() as sess:
         fetches = [staff_detector.staves, staff_detector.staffline_distance]
         detected_staves, distances = sess.run(fetches)

     expected_mean_y = [
         413, 603, 848, 1040, 1286, 1476, 1724, 1915, 2162, 2354, 2604, 2795
     ]
     # Average y position of each detected staff.
     mean_y = np.mean(detected_staves[:, :, 1], axis=1)
     self.assertAllClose(mean_y, expected_mean_y, atol=5)
     self.assertAllEqual(distances, [16] * 12)
示例#7
0
    def __init__(self,
                 num_sections=DEFAULT_NUM_SECTIONS,
                 patch_height=15,
                 patch_width=12,
                 run_options=None):
        """Builds a private graph that extracts staff strips from a PNG file.

        Args:
          num_sections: Forwarded to `StafflineExtractor` as `num_sections`;
            also used to map a y position to an index into the strips.
          patch_height: Forwarded to `StafflineExtractor` as `target_height`.
          patch_width: Stored on the instance; not used while building the
            graph here.
          run_options: Stored on the instance; not used while building the
            graph here.
        """
        self.num_sections = num_sections
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.run_options = run_options

        self.graph = tf.Graph()
        with self.graph.as_default():
            # Placeholders identifying the patch to extract.
            self.filename = tf.placeholder(tf.string, name='filename')
            self.staff_index = tf.placeholder(tf.int64, name='staff_index')
            self.y_position = tf.placeholder(tf.int64, name='y_position')

            page = image_module.decode_music_score_png(
                tf.read_file(self.filename))
            detector = staves_module.StaffDetector(page)
            remover = removal.StaffRemover(detector)
            strip_extractor = StafflineExtractor(
                remover.remove_staves,
                detector,
                num_sections=num_sections,
                target_height=patch_height)

            # A y position of 0 maps to the center element of the staff strips
            # array. Positive positions count up (towards higher notes, towards
            # the top of the image, and smaller indices into the array).
            strip_index = num_sections // 2 - self.y_position
            self.all_stafflines = strip_extractor.extract_staves()
            # The entire extracted horizontal strip of the image.
            self.staffline = self.all_stafflines[self.staff_index, strip_index]

            # Scale for converting image x coordinates to the scaled staff
            # strip from which the patch is extracted: extracted strip height
            # over the unscaled strip height of each staff.
            scaled_height = tf.shape(self.all_stafflines)[2]
            unscaled_heights = tf.multiply(
                DEFAULT_STAFFLINE_DISTANCE_MULTIPLE,
                detector.staffline_distance)
            self.all_staffline_scales = tf.divide(
                tf.to_float(scaled_height), tf.to_float(unscaled_heights))
            self.staffline_scale = self.all_staffline_scales[self.staff_index]
示例#8
0
def pipeline_graph(png_path, staffline_height, patch_width, num_stafflines):
    """Builds the graph for the staffline patches pipeline.

    The page is decoded, its staves are detected and removed, staff strips are
    extracted from the staves-removed image, and the strips are cut into
    patches.

    Args:
      png_path: Path to the input png. String scalar tensor.
      staffline_height: Height of a staffline. int.
      patch_width: Width of a patch. int.
      num_stafflines: Number of stafflines to extract around each staff. int.

    Returns:
      A tensor representing the staffline patches. float32 with shape
          (num_patches, staffline_height, patch_width).
    """
    page = image.decode_music_score_png(tf.read_file(png_path))
    detector = staves.StaffDetector(page)
    remover = removal.StaffRemover(detector)
    extractor = staffline_extractor.StafflineExtractor(
        remover.remove_staves,
        detector,
        target_height=staffline_height,
        num_sections=num_stafflines)
    # Give the extracted strips a stable name in the graph.
    stafflines = tf.identity(extractor.extract_staves(), name='stafflines')
    return _extract_patches(stafflines, patch_width)