Exemplo n.º 1
0
 def predict_sequence(frame_probs, onset_probs, frame_predictions,
                      onset_predictions, offset_predictions, velocity_values,
                      hparams):
   """Decode one example's model outputs into a serialized NoteSequence.

   When `hparams.drums_only` is set, the onset predictions are also passed as
   the frame and offset predictions, decoding runs with `onsets_only=True`,
   and every decoded note is marked with `is_drum = True`. Otherwise the
   frame, onset and offset predictions are forwarded unchanged and
   `onsets_only` is left at the decoder's default.

   Args:
     frame_probs: Frame probabilities, forwarded to the decoder.
     onset_probs: Onset probabilities, forwarded to the decoder.
     frame_predictions: Frame predictions (unused in drums-only mode).
     onset_predictions: Onset predictions.
     offset_predictions: Offset predictions (unused in drums-only mode).
     velocity_values: Velocity values, forwarded to the decoder.
     hparams: Hyperparameters; `drums_only` selects the decoding mode.

   Returns:
     The result of `SerializeToString()` on the decoded sequence.
   """
   drums = hparams.drums_only
   # Only the drums path overrides `onsets_only`; the melodic path relies on
   # the decoder's default.
   extra_kwargs = {'onsets_only': True} if drums else {}
   sequence_prediction = infer_util.predict_sequence(
       frame_probs=frame_probs,
       onset_probs=onset_probs,
       frame_predictions=onset_predictions if drums else frame_predictions,
       onset_predictions=onset_predictions,
       offset_predictions=onset_predictions if drums else offset_predictions,
       velocity_values=velocity_values,
       min_pitch=constants.MIN_MIDI_PITCH,
       hparams=hparams,
       **extra_kwargs)
   if drums:
     for note in sequence_prediction.notes:
       note.is_drum = True
   return sequence_prediction.SerializeToString()
Exemplo n.º 2
0
 def _predict(frame_predictions, onset_predictions, offset_predictions,
              velocity_values):
   """Decode frame/onset/offset predictions into a serialized NoteSequence.

   `hparams` is not a parameter: it is looked up from the enclosing (or
   module) scope when the function runs. Pitches are offset by
   `constants.MIN_MIDI_PITCH`.
   """
   decoded = infer_util.predict_sequence(
       hparams=hparams,
       min_pitch=constants.MIN_MIDI_PITCH,
       velocity_values=velocity_values,
       offset_predictions=offset_predictions,
       onset_predictions=onset_predictions,
       frame_predictions=frame_predictions)
   serialized = decoded.SerializeToString()
   return serialized
Exemplo n.º 3
0
 def _predict(frame_probs, onset_probs, frame_predictions,
              onset_predictions, offset_predictions, velocity_values):
   """Run the Python-side decoder and return the NoteSequence as bytes.

   `hparams` is not a parameter: it is read from the enclosing (or module)
   scope at call time. `min_pitch` is fixed to `constants.MIN_MIDI_PITCH`.
   """
   decoder_kwargs = dict(
       frame_probs=frame_probs,
       onset_probs=onset_probs,
       frame_predictions=frame_predictions,
       onset_predictions=onset_predictions,
       offset_predictions=offset_predictions,
       velocity_values=velocity_values,
       hparams=hparams,
       min_pitch=constants.MIN_MIDI_PITCH)
   return infer_util.predict_sequence(**decoder_kwargs).SerializeToString()
def magenta_decoding(onset_prob,
                     frame_prob,
                     offset_prob,
                     threshold=0.5,
                     viterbi=False,
                     offset_threshold=0.0):
    """Decode onset/frame/offset probabilities with the 'onsets_frames' config.

    Onset and frame probabilities are binarized with ``prob > threshold``;
    offset probabilities with ``prob > offset_threshold``. The default of 0.0
    reproduces the previous hard-coded behaviour. Pass
    ``offset_threshold=threshold`` to use the same cut-off everywhere.

    Args:
      onset_prob: Onset probabilities. They are compared element-wise with
        ``>``, so presumably a NumPy array — TODO confirm against callers.
      frame_prob: Frame probabilities (passed both raw and binarized).
      offset_prob: Offset probabilities (only the binarized form is passed).
      threshold: Cut-off applied to the onset and frame probabilities.
      viterbi: If True, ``viterbi_decoding`` is enabled for this call only.
      offset_threshold: Cut-off applied to the offset probabilities.

    Returns:
      Whatever ``predict_sequence`` returns for these inputs.
    """
    import copy

    # ``config.hparams`` is the object stored in the module-level CONFIG_MAP.
    # Mutating it in place (as this function used to) left
    # ``viterbi_decoding=True`` switched on for every later call, including
    # calls with viterbi=False. Decode with a private copy instead.
    hparams = copy.deepcopy(configs.CONFIG_MAP['onsets_frames'].hparams)
    if viterbi:
        hparams.viterbi_decoding = True
    return predict_sequence(
        frame_prob,
        onset_prob,
        frame_prob > threshold,
        onset_prob > threshold,
        offset_prob > offset_threshold,
        velocity_values=None,
        hparams=hparams,
        # Hard-coded 21; other decoders in this file use
        # constants.MIN_MIDI_PITCH — presumably the same value, TODO confirm.
        min_pitch=21)
Exemplo n.º 5
0
def _calculate_metrics_py(frame_probs,
                          onset_probs,
                          frame_predictions,
                          onset_predictions,
                          offset_predictions,
                          velocity_values,
                          sequence_label_str,
                          frame_labels,
                          sequence_id,
                          hparams,
                          min_pitch,
                          max_pitch,
                          onsets_only,
                          restrict_to_pitch=None):
    """Python logic for calculating metrics on a single example.

    The function:
    1. Decodes the model outputs into a NoteSequence.
    2. Optionally time-shifts the reference sequence.
    3. Scores the prediction against the reference with mir_eval: note-level
       precision/recall/F1 with and without offsets, and with and without
       velocities.
    4. Renders the prediction as a pianoroll padded or truncated to the number
       of label frames.

    Args:
      frame_probs: Frame probabilities, forwarded to the decoder.
      onset_probs: Onset probabilities, forwarded to the decoder.
      frame_predictions: Frame predictions, forwarded to the decoder.
      onset_predictions: Onset predictions, forwarded to the decoder.
      offset_predictions: Offset predictions, forwarded to the decoder.
      velocity_values: Velocity values, forwarded to the decoder.
      sequence_label_str: Serialized reference ``NoteSequence``.
      frame_labels: Reference frame labels; only ``shape[0]`` (the frame
        count) is read here.
      sequence_id: Identifier used in log messages.
      hparams: Hyperparameters. ``backward_shift_amount_ms`` is read here,
        and frames-per-second is derived from them.
      min_pitch: Lowest pitch for decoding and for the pianoroll.
      max_pitch: Highest pitch for the pianoroll.
      onsets_only: Forwarded to the decoder.
      restrict_to_pitch: Optional pitch filter passed to
        ``sequence_to_valued_intervals``.

    Returns:
      A 13-tuple. The first twelve items are precision/recall/F1 for, in order:
      notes; notes with velocity; notes with offsets; notes with offsets and
      velocity. Each value is wrapped in a one-element list. The last item is
      the processed frame predictions wrapped in a list. If the reference has
      no notes, the twelve metrics are empty lists and the frame predictions
      are returned unwrapped.

    Raises:
      AssertionError: If shifting the reference times skipped any notes.
    """
    tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                    frame_labels.shape[0])

    sequence_prediction = infer_util.predict_sequence(
        frame_probs=frame_probs,
        onset_probs=onset_probs,
        frame_predictions=frame_predictions,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values,
        min_pitch=min_pitch,
        hparams=hparams,
        onsets_only=onsets_only)

    sequence_label = music_pb2.NoteSequence.FromString(sequence_label_str)

    if hparams.backward_shift_amount_ms:
        # Add backward_shift_amount_ms (converted to seconds) to every
        # reference time.
        def shift_notesequence(ns_time):
            return ns_time + hparams.backward_shift_amount_ms / 1000.

        shifted_sequence_label, skipped_notes = (
            sequences_lib.adjust_notesequence_times(sequence_label,
                                                    shift_notesequence))
        # Explicit raise rather than `assert` so the check survives `python -O`.
        if skipped_notes != 0:
            raise AssertionError(
                'Shifting reference times skipped %d notes' % skipped_notes)
        sequence_label = shifted_sequence_label

    est_intervals, est_pitches, est_velocities = (
        sequence_to_valued_intervals(
            sequence_prediction, restrict_to_pitch=restrict_to_pitch))

    ref_intervals, ref_pitches, ref_velocities = (
        sequence_to_valued_intervals(
            sequence_label, restrict_to_pitch=restrict_to_pitch))

    processed_frame_predictions = sequences_lib.sequence_to_pianoroll(
        sequence_prediction,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_pitch=min_pitch,
        max_pitch=max_pitch).active

    # Align the decoded pianoroll with the labels along the frame axis.
    num_label_frames = frame_labels.shape[0]
    num_pred_frames = processed_frame_predictions.shape[0]
    if num_pred_frames < num_label_frames:
        # Pad transcribed frames with silence.
        processed_frame_predictions = np.pad(
            processed_frame_predictions,
            [(0, num_label_frames - num_pred_frames), (0, 0)], 'constant')
    elif num_pred_frames > num_label_frames:
        # Truncate transcribed frames.
        processed_frame_predictions = (
            processed_frame_predictions[:num_label_frames, :])

    if len(ref_pitches) == 0:
        tf.logging.info(
            'Reference pitches were length 0, returning empty metrics for %s:',
            sequence_id)
        # NOTE(review): unlike the normal return below, the frame predictions
        # are not wrapped in a list here, so their rank differs by one between
        # the two paths. Kept unchanged for caller compatibility — confirm
        # whether callers rely on broadcasting before unifying.
        return tuple([[]] * 12 + [processed_frame_predictions])

    # mir_eval expects pitches in Hz; convert each side once, not per call.
    ref_hz = pretty_midi.note_number_to_hz(ref_pitches)
    est_hz = pretty_midi.note_number_to_hz(est_pitches)

    # Onset/pitch matching only: offset_ratio=None makes mir_eval ignore
    # offsets.
    note_precision, note_recall, note_f1, _ = (
        mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals, ref_hz, est_intervals, est_hz, offset_ratio=None))

    (note_with_velocity_precision, note_with_velocity_recall,
     note_with_velocity_f1, _) = (
        mir_eval.transcription_velocity.precision_recall_f1_overlap(
            ref_intervals=ref_intervals,
            ref_pitches=ref_hz,
            ref_velocities=ref_velocities,
            est_intervals=est_intervals,
            est_pitches=est_hz,
            est_velocities=est_velocities,
            offset_ratio=None))

    # The same two metrics again, now also matching offsets (mir_eval's
    # default offset_ratio).
    (note_with_offsets_precision, note_with_offsets_recall,
     note_with_offsets_f1, _) = (
        mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals, ref_hz, est_intervals, est_hz))

    (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
     note_with_offsets_velocity_f1, _) = (
        mir_eval.transcription_velocity.precision_recall_f1_overlap(
            ref_intervals=ref_intervals,
            ref_pitches=ref_hz,
            ref_velocities=ref_velocities,
            est_intervals=est_intervals,
            est_pitches=est_hz,
            est_velocities=est_velocities))

    tf.logging.info(
        'Metrics for %s: Note F1 %f, Note w/ velocity F1 %f, Note w/ offsets F1 %f, '
        'Note w/ offsets & velocity: %f', sequence_id, note_f1,
        note_with_velocity_f1, note_with_offsets_f1,
        note_with_offsets_velocity_f1)
    # Return 1-d tensors for the metrics.
    return ([note_precision], [note_recall], [note_f1],
            [note_with_velocity_precision], [note_with_velocity_recall],
            [note_with_velocity_f1],
            [note_with_offsets_precision], [note_with_offsets_recall],
            [note_with_offsets_f1],
            [note_with_offsets_velocity_precision],
            [note_with_offsets_velocity_recall],
            [note_with_offsets_velocity_f1],
            [processed_frame_predictions])