Example #1
0
    def _note_metrics(labels, predictions):
        """Scores estimated notes against reference notes for one example.

        Intended to be used as a pyfunc around precision_recall_f1_overlap.
        Both inputs are converted with pianoroll_to_note_sequence and then
        sequence_to_valued_intervals before being scored. `hparams` and
        `offset_ratio` are read from the surrounding scope.

        Args:
          labels: Reference input accepted by pianoroll_to_note_sequence.
          predictions: Estimated input accepted by pianoroll_to_note_sequence.

        Returns:
          A (precision, recall, f1) tuple. All three values are 0. when either
          the estimated or the reference intervals are empty.
        """
        min_ms = hparams.min_duration_ms

        predicted_notes = pianoroll_to_note_sequence(
            predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=min_ms)
        label_notes = pianoroll_to_note_sequence(
            labels,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=min_ms)

        predicted_spans, predicted_pitches = sequence_to_valued_intervals(
            predicted_notes, min_ms)
        label_spans, label_pitches = sequence_to_valued_intervals(
            label_notes, min_ms)

        # Nothing can be matched when either side has no intervals, so skip
        # the scoring call and report zeros.
        if not (predicted_spans.size and label_spans.size):
            return 0., 0., 0.

        # The scorer is given pitches in Hz, hence the conversion from note
        # numbers; its fourth return value is not needed here.
        scores = precision_recall_f1_overlap(
            label_spans,
            pretty_midi.note_number_to_hz(label_pitches),
            predicted_spans,
            pretty_midi.note_number_to_hz(predicted_pitches),
            offset_ratio=offset_ratio)
        precision, recall, f1, _ = scores

        return precision, recall, f1
Example #2
0
  def _note_metrics(labels, predictions):
    """Returns note precision, recall and F1 for a single example.

    Intended to be used as a pyfunc around precision_recall_f1_overlap. Each
    input is turned into a note sequence with
    sequences_lib.pianoroll_to_note_sequence and then into interval/pitch
    arrays with infer_util.sequence_to_valued_intervals. `hparams` and
    `offset_ratio` are read from the surrounding scope.

    Args:
      labels: Reference input accepted by pianoroll_to_note_sequence.
      predictions: Estimated input accepted by pianoroll_to_note_sequence.

    Returns:
      A (precision, recall, f1) tuple. All three values are 0. when either
      the estimated or the reference intervals are empty.
    """
    min_ms = hparams.min_duration_ms

    predicted_notes = sequences_lib.pianoroll_to_note_sequence(
        predictions,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_duration_ms=min_ms)
    label_notes = sequences_lib.pianoroll_to_note_sequence(
        labels,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_duration_ms=min_ms)

    # The third value returned by sequence_to_valued_intervals is unused here.
    predicted_spans, predicted_pitches, _ = (
        infer_util.sequence_to_valued_intervals(predicted_notes, min_ms))
    label_spans, label_pitches, _ = (
        infer_util.sequence_to_valued_intervals(label_notes, min_ms))

    # With no intervals on either side there is nothing to match.
    if 0 in (predicted_spans.size, label_spans.size):
      return 0., 0., 0.

    # The scorer is given pitches in Hz, hence the conversion from note
    # numbers; its fourth return value is not needed here.
    scores = precision_recall_f1_overlap(
        label_spans,
        pretty_midi.note_number_to_hz(label_pitches),
        predicted_spans,
        pretty_midi.note_number_to_hz(predicted_pitches),
        offset_ratio=offset_ratio)
    precision, recall, f1, _ = scores

    return precision, recall, f1
Example #3
0
    def testSequenceToValuedIntervals(self):
        """Checks that only the non-zero-duration note is returned."""
        note_sequence = music_pb2.NoteSequence()
        # The first note lasts one second; the second has 0 duration and
        # should therefore be dropped.
        for start, end in ((1.0, 2.0), (3.0, 3.0)):
            note_sequence.notes.add(pitch=60, start_time=start, end_time=end)

        result = infer_util.sequence_to_valued_intervals(
            note_sequence, min_duration_ms=0)
        actual_intervals, actual_pitches = result

        np.testing.assert_array_equal([[1., 2.]], actual_intervals)
        np.testing.assert_array_equal([60], actual_pitches)
Example #4
0
  def testSequenceToValuedIntervals(self):
    """Checks that a 0-duration note is left out of the result."""
    note_sequence = music_pb2.NoteSequence()
    notes = note_sequence.notes
    # A one-second note that should be kept.
    notes.add(pitch=60, start_time=1.0, end_time=2.0)
    # Starts and ends at the same time, so it should be dropped.
    notes.add(pitch=60, start_time=3.0, end_time=3.0)

    expected_intervals = [[1., 2.]]
    expected_pitches = [60]

    actual_intervals, actual_pitches = (
        infer_util.sequence_to_valued_intervals(
            note_sequence, min_duration_ms=0))

    np.testing.assert_array_equal(expected_intervals, actual_intervals)
    np.testing.assert_array_equal(expected_pitches, actual_pitches)
Example #5
0
def _calculate_metrics_py(
    frame_predictions, onset_predictions, offset_predictions, velocity_values,
    sequence_label_str, frame_labels, sequence_id, hparams):
  """Computes note metrics and a processed piano roll for a single example.

  The predictions are decoded into a NoteSequence, compared against the
  reference NoteSequence with three mir_eval scorers, and then rendered back
  to a piano roll whose number of frames matches `frame_labels`.

  Args:
    frame_predictions: Passed as `frames` to
      sequences_lib.pianoroll_to_note_sequence.
    onset_predictions: Forwarded to pianoroll_to_note_sequence, or replaced by
      None when hparams.predict_onset_threshold is falsy.
    offset_predictions: Forwarded to pianoroll_to_note_sequence, or replaced by
      None when hparams.predict_offset_threshold is falsy.
    velocity_values: Forwarded to pianoroll_to_note_sequence.
    sequence_label_str: Serialized music_pb2.NoteSequence with the reference
      notes.
    frame_labels: Only its first dimension is used, as the target number of
      frames.
    sequence_id: Identifier used in the log messages.
    hparams: Hyperparameters; predict_onset_threshold,
      predict_offset_threshold and backward_shift_amount_ms are read here and
      the object is also passed to data.hparams_frames_per_second.

  Returns:
    A 10-tuple: precision, recall and F1 for the scorer called with
    offset_ratio=None, the same three for the scorer called with its default
    offset_ratio, the same three for the velocity-aware scorer, and finally
    the processed frame predictions.
  """
  num_label_frames = frame_labels.shape[0]
  tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                  num_label_frames)

  # Onset/offset predictions are only passed on when the matching threshold
  # hparam is truthy.
  onset_predictions = (
      onset_predictions if hparams.predict_onset_threshold else None)
  offset_predictions = (
      offset_predictions if hparams.predict_offset_threshold else None)

  predicted_sequence = sequences_lib.pianoroll_to_note_sequence(
      frames=frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=constants.MIN_MIDI_PITCH,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values)

  label_sequence = music_pb2.NoteSequence.FromString(sequence_label_str)

  if hparams.backward_shift_amount_ms:
    # Add the configured shift (milliseconds divided by 1000) to every label
    # time.
    shift_amount = hparams.backward_shift_amount_ms / 1000.
    label_sequence, skipped_notes = sequences_lib.adjust_notesequence_times(
        label_sequence, lambda ns_time: ns_time + shift_amount)
    assert skipped_notes == 0

  est_intervals, est_pitches, est_velocities = (
      infer_util.sequence_to_valued_intervals(predicted_sequence))
  ref_intervals, ref_pitches, ref_velocities = (
      infer_util.sequence_to_valued_intervals(label_sequence))

  # All three scorers are given pitches in Hz; convert once and reuse.
  ref_hz = pretty_midi.note_number_to_hz(ref_pitches)
  est_hz = pretty_midi.note_number_to_hz(est_pitches)

  # Scorer called with offset_ratio=None.
  onset_scores = mir_eval.transcription.precision_recall_f1_overlap(
      ref_intervals, ref_hz, est_intervals, est_hz, offset_ratio=None)
  note_precision, note_recall, note_f1, _ = onset_scores

  # Same scorer with its default offset_ratio.
  offset_scores = mir_eval.transcription.precision_recall_f1_overlap(
      ref_intervals, ref_hz, est_intervals, est_hz)
  (note_with_offsets_precision, note_with_offsets_recall,
   note_with_offsets_f1, _) = offset_scores

  # Velocity-aware scorer.
  velocity_scores = (
      mir_eval.transcription_velocity.precision_recall_f1_overlap(
          ref_intervals=ref_intervals,
          ref_pitches=ref_hz,
          ref_velocities=ref_velocities,
          est_intervals=est_intervals,
          est_pitches=est_hz,
          est_velocities=est_velocities))
  (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
   note_with_offsets_velocity_f1, _) = velocity_scores

  processed_frame_predictions = sequences_lib.sequence_to_pianoroll(
      predicted_sequence,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_pitch=constants.MIN_MIDI_PITCH,
      max_pitch=constants.MAX_MIDI_PITCH).active

  # Make the transcribed roll exactly as long as the labels.
  missing_frames = num_label_frames - processed_frame_predictions.shape[0]
  if missing_frames > 0:
    # Too short: pad the end with silence.
    processed_frame_predictions = np.pad(
        processed_frame_predictions, [(0, missing_frames), (0, 0)],
        'constant')
  elif missing_frames < 0:
    # Too long: drop the extra trailing frames.
    processed_frame_predictions = (
        processed_frame_predictions[:num_label_frames, :])

  tf.logging.info(
      'Metrics for %s: Note F1 %f, Note w/ offsets F1 %f, '
      'Note w/ offsets & velocity: %f', sequence_id, note_f1,
      note_with_offsets_f1, note_with_offsets_velocity_f1)

  return (note_precision, note_recall, note_f1,
          note_with_offsets_precision, note_with_offsets_recall,
          note_with_offsets_f1,
          note_with_offsets_velocity_precision,
          note_with_offsets_velocity_recall,
          note_with_offsets_velocity_f1,
          processed_frame_predictions)
Example #6
0
def _calculate_metrics_py(
    frame_predictions, onset_predictions, offset_predictions, velocity_values,
    sequence_label_str, frame_labels, sequence_id, hparams, min_pitch,
    max_pitch):
  """Computes note metrics and a processed piano roll for a single example.

  The predictions are decoded into a NoteSequence, compared against the
  reference NoteSequence with three mir_eval scorers, and then rendered back
  to a piano roll whose number of frames matches `frame_labels`.

  Args:
    frame_predictions: Passed as `frames` to
      sequences_lib.pianoroll_to_note_sequence.
    onset_predictions: Forwarded to pianoroll_to_note_sequence, or replaced by
      None when hparams.predict_onset_threshold is falsy.
    offset_predictions: Forwarded to pianoroll_to_note_sequence, or replaced by
      None when hparams.predict_offset_threshold is falsy.
    velocity_values: Forwarded to pianoroll_to_note_sequence.
    sequence_label_str: Serialized music_pb2.NoteSequence with the reference
      notes.
    frame_labels: Only its first dimension is used, as the target number of
      frames.
    sequence_id: Identifier used in the log messages.
    hparams: Hyperparameters; predict_onset_threshold,
      predict_offset_threshold and backward_shift_amount_ms are read here and
      the object is also passed to data.hparams_frames_per_second.
    min_pitch: Passed as `min_midi_pitch` when decoding the predictions and as
      `min_pitch` when rendering the piano roll.
    max_pitch: Passed as `max_pitch` when rendering the piano roll.

  Returns:
    A 10-tuple: precision, recall and F1 for the scorer called with
    offset_ratio=None, the same three for the scorer called with its default
    offset_ratio, the same three for the velocity-aware scorer, and finally
    the processed frame predictions.
  """
  target_frames = frame_labels.shape[0]
  tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                  target_frames)

  # Onset/offset predictions are only passed on when the matching threshold
  # hparam is truthy.
  onset_predictions = (
      onset_predictions if hparams.predict_onset_threshold else None)
  offset_predictions = (
      offset_predictions if hparams.predict_offset_threshold else None)

  predicted_sequence = sequences_lib.pianoroll_to_note_sequence(
      frames=frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=min_pitch,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values)

  label_sequence = music_pb2.NoteSequence.FromString(sequence_label_str)

  if hparams.backward_shift_amount_ms:
    # Add the configured shift (milliseconds divided by 1000) to every label
    # time.
    shift_amount = hparams.backward_shift_amount_ms / 1000.

    def _shifted(ns_time):
      return ns_time + shift_amount

    label_sequence, skipped_notes = sequences_lib.adjust_notesequence_times(
        label_sequence, _shifted)
    assert skipped_notes == 0

  est_intervals, est_pitches, est_velocities = (
      infer_util.sequence_to_valued_intervals(predicted_sequence))
  ref_intervals, ref_pitches, ref_velocities = (
      infer_util.sequence_to_valued_intervals(label_sequence))

  # All three scorers are given pitches in Hz; convert once and reuse.
  ref_hz = pretty_midi.note_number_to_hz(ref_pitches)
  est_hz = pretty_midi.note_number_to_hz(est_pitches)

  # Scorer called with offset_ratio=None.
  onset_scores = mir_eval.transcription.precision_recall_f1_overlap(
      ref_intervals, ref_hz, est_intervals, est_hz, offset_ratio=None)
  note_precision, note_recall, note_f1, _ = onset_scores

  # Same scorer with its default offset_ratio.
  offset_scores = mir_eval.transcription.precision_recall_f1_overlap(
      ref_intervals, ref_hz, est_intervals, est_hz)
  (note_with_offsets_precision, note_with_offsets_recall,
   note_with_offsets_f1, _) = offset_scores

  # Velocity-aware scorer.
  velocity_scores = (
      mir_eval.transcription_velocity.precision_recall_f1_overlap(
          ref_intervals=ref_intervals,
          ref_pitches=ref_hz,
          ref_velocities=ref_velocities,
          est_intervals=est_intervals,
          est_pitches=est_hz,
          est_velocities=est_velocities))
  (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
   note_with_offsets_velocity_f1, _) = velocity_scores

  processed_frame_predictions = sequences_lib.sequence_to_pianoroll(
      predicted_sequence,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_pitch=min_pitch, max_pitch=max_pitch).active

  # Make the transcribed roll exactly as long as the labels.
  produced_frames = processed_frame_predictions.shape[0]
  if produced_frames < target_frames:
    # Too short: pad the end with silence.
    processed_frame_predictions = np.pad(
        processed_frame_predictions,
        [(0, target_frames - produced_frames), (0, 0)], 'constant')
  elif produced_frames > target_frames:
    # Too long: drop the extra trailing frames.
    processed_frame_predictions = (
        processed_frame_predictions[:target_frames, :])

  tf.logging.info(
      'Metrics for %s: Note F1 %f, Note w/ offsets F1 %f, '
      'Note w/ offsets & velocity: %f', sequence_id, note_f1,
      note_with_offsets_f1, note_with_offsets_velocity_f1)

  return (note_precision, note_recall, note_f1,
          note_with_offsets_precision, note_with_offsets_recall,
          note_with_offsets_f1,
          note_with_offsets_velocity_precision,
          note_with_offsets_velocity_recall,
          note_with_offsets_velocity_f1,
          processed_frame_predictions)