Ejemplo n.º 1
0
 def get_track_label(track_label, track=None):
     """Convenient function to get track labels.

     :param track_label: Which label to produce.
         ``'name'`` returns ``track.name``, falling back to the instrument
         class of ``track.program`` when the name is empty.
         ``'program'`` returns ``pretty_midi.program_to_instrument_name`` of
         ``track.program``.
         ``'family'`` returns ``pretty_midi.program_to_instrument_class`` of
         ``track.program``.
         Any other value is returned unchanged as a literal label.
     :param track: Object exposing ``name`` and ``program`` attributes.
         Required for the three keyword labels above.
     :return: The label.
     """
     if track_label == 'name':
         # An empty name carries no information; use the instrument family.
         if track.name != "":
             return track.name
         return pretty_midi.program_to_instrument_class(track.program)
     if track_label == 'program':
         return pretty_midi.program_to_instrument_name(track.program)
     if track_label == 'family':
         return pretty_midi.program_to_instrument_class(track.program)
     # Any other value is used verbatim, whether or not a track was given.
     return track_label
Ejemplo n.º 2
0
 def get_track_label(track_label, track=None):
     """Map a label keyword to the matching attribute of a track.

     ``"name"``    -> ``track.name``
     ``"program"`` -> ``pretty_midi.program_to_instrument_name(track.program)``
     ``"family"``  -> ``pretty_midi.program_to_instrument_class(track.program)``

     Anything else is handed back unchanged, so callers may pass a literal
     label instead of a keyword.
     """
     if track_label not in ("name", "program", "family"):
         return track_label
     if track_label == "name":
         return track.name
     if track_label == "program":
         return pretty_midi.program_to_instrument_name(track.program)
     return pretty_midi.program_to_instrument_class(track.program)
Ejemplo n.º 3
0
def get_instrument_info(instrument):
    """Describe a pretty_midi.Instrument instance as a plain dictionary.

    Keys:
        program_num  -- ``instrument.program``
        program_name -- ``pretty_midi.program_to_instrument_name`` of the program
        name         -- ``instrument.name`` with surrounding whitespace stripped
        is_drum      -- ``instrument.is_drum``
        family_num   -- ``int(program) // 8`` (index of the 8-program block)
        family_name  -- ``pretty_midi.program_to_instrument_class`` of the program
    """
    program = instrument.program
    return dict(
        program_num=program,
        program_name=pretty_midi.program_to_instrument_name(program),
        name=instrument.name.strip(),
        is_drum=instrument.is_drum,
        family_num=int(program) // 8,
        family_name=pretty_midi.program_to_instrument_class(program),
    )
Ejemplo n.º 4
0
def show_pianoroll(xs,
                   min_pitch=45,
                   max_pitch=85,
                   programs=(0, 0, 0, 0),
                   save_dir=None):
    """Plot the four tracks of the first multitrack piano roll in ``xs``.

    Each track's active cells are drawn as a scatter plot in its own subplot
    of a 1x4 figure.

    :param xs: Batch of multitrack piano rolls supporting elementwise ``> 0``
        (e.g. a numpy array). Values > 0 are treated as active notes. Each
        item must be 3-D; axis 2 is moved to the front and the result is
        unpacked as ``(n_tracks, time_step, pitch)``, so items are presumably
        laid out as ``(time, pitch, tracks)`` — TODO confirm with callers.
        At least 4 tracks are expected.
    :param min_pitch: Lower y-axis (pitch) limit shared by all subplots.
    :param max_pitch: Upper y-axis (pitch) limit shared by all subplots.
    :param programs: MIDI program number of each of the four tracks, used to
        label the x-axes. Pass an empty sequence or None to omit the labels.
    :param save_dir: Currently unused; accepted for signature parity with
        ``plot_pianoroll``.
    :return: None. Interactive mode is switched on and the figure is left open.
    """
    # Binarize: any strictly positive value counts as an active note.
    xs = xs > 0

    # Move axis 2 to the front so each item is indexed track-first.
    xs = [np.moveaxis(np.array(x), 2, 0) for x in xs]

    assert xs[0].ndim == 3, \
        'Pianoroll shape must have 3 dims, Got %d' % xs[0].ndim
    _, time_step, _ = xs[0].shape

    plt.ion()
    fig = plt.figure(figsize=(15, 4))

    x = xs[0]
    for track in range(4):
        # add_subplot already attaches the axes to the figure.
        ax = fig.add_subplot(1, 4, track + 1)
        nz = np.nonzero(x[track])

        if programs:
            instrument_class = pretty_midi.program_to_instrument_class(
                programs[track])
            ax.set_xlabel(f'Time({instrument_class})')

        # Only the leftmost subplot keeps a pitch axis.
        if track == 0:
            ax.set_ylabel('Pitch')
        else:
            ax.set_yticks([])

        ax.scatter(nz[0], nz[1], s=np.pi * 3, color='bgrcmk'[track])
        ax.set_ylim(min_pitch, max_pitch)
        ax.set_xlim(0, time_step)
Ejemplo n.º 5
0
def midi_to_stacked_piano_roll(midi, hop_seconds=1024/4/8000., min_note=20,
                               max_note=100):
    '''
    Build a matrix of per-instrument-class piano rolls stacked vertically.

    Every instrument that is neither a drum nor in a class whose name contains
    'Effects' is rendered with ``get_piano_roll(fs=1./hop_seconds)``, cropped
    to pitch rows ``min_note`` (inclusive) to ``max_note`` (exclusive), and
    added into the band of rows reserved for its instrument class. Instruments
    of the same class accumulate into the same band.

    :parameters:
        - midi : pretty_midi.PrettyMIDI
            MIDI data object
        - hop_seconds : float
            Time between each column in the piano roll matrix
        - min_note : int
            Lowest pitch row kept (inclusive).
        - max_note : int
            Upper pitch bound (exclusive); each band has
            ``max_note - min_note`` rows.

    :returns:
        - stacked_piano_roll : np.ndarray
            Array of shape ``(14 * (max_note - min_note), n_frames)`` with
            ``n_frames == len(np.arange(0, midi.get_end_time(), hop_seconds))``.
            Bands, top to bottom: Piano, Chromatic Percussion, Organ, Guitar,
            Bass, Strings, Ensemble, Brass, Reed, Pipe, Synth Lead, Synth Pad,
            Ethnic, Percussive.
    '''
    band_height = max_note - min_note
    # Effects classes are skipped in the loop below, so they get no band.
    band_order = ['Piano', 'Chromatic Percussion', 'Organ', 'Guitar', 'Bass',
                  'Strings', 'Ensemble', 'Brass', 'Reed', 'Pipe',
                  'Synth Lead', 'Synth Pad', 'Ethnic', 'Percussive']
    band_start = {family: position * band_height
                  for position, family in enumerate(band_order)}

    frame_times = np.arange(0, midi.get_end_time(), hop_seconds)
    stacked = np.zeros((len(band_order) * band_height, frame_times.shape[0]))

    for instrument in midi.instruments:
        family = pretty_midi.program_to_instrument_class(instrument.program)
        if instrument.is_drum or 'Effects' in family:
            continue
        roll = instrument.get_piano_roll(fs=1./hop_seconds)
        top = band_start[family]
        width = roll.shape[1]
        stacked[top:top + band_height, :width] += roll[min_note:max_note]
    return stacked
def get_instrument_classes(msd_id) -> list:
    """
    Return the PrettyMIDI instrument classes of the MIDI matched to an MSD id.

    The MIDI file is located with ``get_matched_midi_md5`` (against
    ``MSD_SCORE_MATCHES``) and ``get_midi_path`` (under
    ``args.path_dataset_dir``). Every non-drum instrument contributes its
    ``program_to_instrument_class`` name, in instrument order; every drum
    instrument then contributes one ``"Drums"`` entry at the end.

    :param msd_id: the MSD id
    :return: the list of instrument classes (never empty)
    :raises ValueError: if the MIDI file contains no instruments
    """
    midi_md5 = get_matched_midi_md5(msd_id, MSD_SCORE_MATCHES)
    midi_path = get_midi_path(msd_id, midi_md5, args.path_dataset_dir)
    pm = PrettyMIDI(midi_path)

    # Single pass; drums are collected separately so they end up last.
    classes = []
    drums = []
    for instrument in pm.instruments:
        if instrument.is_drum:
            drums.append("Drums")
        else:
            classes.append(program_to_instrument_class(instrument.program))
    classes.extend(drums)

    if not classes:
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers still catch it.
        raise ValueError(f"No program classes for {msd_id}: " f"{len(classes)}")
    return classes
Ejemplo n.º 7
0
                pdf_path = unidecode.unidecode(row['pdf'].split("/")[-1])
                midi_path = unidecode.unidecode(row['midi'].split("/")[-1])

                if os.path.isfile(row['pdf']) and os.path.isfile(row['midi']):
                    try:
                        # Get midi length in seconds
                        midi_data = pretty_midi.PrettyMIDI(row['midi'])
                    except:
                        print("----", "Midi file seems corruct.")
                        continue

                    # Check midi file has only piano tracks
                    non_piano_instruments = 0
                    for inst in midi_data.instruments:
                        # Only consider instruments from the piano family
                        if pretty_midi.program_to_instrument_class(
                                inst.program) != "Piano":
                            non_piano_instruments += 1

                    if non_piano_instruments == 0:
                        # Check midi is not too short
                        midi_length = midi_data.get_end_time()

                        if midi_length > MIN_LENGTH:
                            shutil.copyfile(
                                row['pdf'],
                                os.path.join(opt.out, "pdf", pdf_path))
                            shutil.copyfile(
                                row['midi'],
                                os.path.join(opt.out, "midi", midi_path))

                            row['series'] = unidecode.unidecode(row['series'])
Ejemplo n.º 8
0
def plot_pianoroll(iteration,
                   xs,
                   fake_xs,
                   min_pitch=45,
                   max_pitch=85,
                   programs=(0, 0, 0, 0),
                   save_dir=None):
    """Plot real vs. generated multitrack piano rolls and save them as a PNG.

    The figure is a 2x2 grid of panels, one per sample ``i`` in ``range(4)``.
    Inside each panel, the top row shows the four tracks of ``xs[i]`` and the
    bottom row shows the four tracks of ``fake_xs[i]``.

    :param iteration: If an ``int``, the figure is titled
        ``'iteration: <n>'`` and saved as ``sample_iteration_%05d.png``;
        otherwise it is titled ``'Inference'`` and saved as
        ``sample_inference.png``.
    :param xs: Batch of real piano rolls supporting elementwise ``> 0``
        (e.g. a numpy array); values > 0 are treated as active notes. At
        least 4 items are expected. Each item must be 3-D; axis 2 is moved to
        the front and unpacked as ``(n_tracks, time_step, pitch)``, so items
        are presumably ``(time, pitch, tracks)`` — TODO confirm with callers.
    :param fake_xs: Batch of generated piano rolls, same layout as ``xs``.
    :param min_pitch: Lower y-axis (pitch) limit shared by all subplots.
    :param max_pitch: Upper y-axis (pitch) limit shared by all subplots.
    :param programs: MIDI program number of each of the four tracks, used to
        label the x-axes. Pass an empty sequence or None to omit the labels.
    :param save_dir: Directory the PNG is written to. It must be given:
        ``os.path.join(None, ...)`` raises ``TypeError``.
    :return: None. The figure is closed after saving.
    """
    # Binarize: any strictly positive value counts as an active note.
    xs = xs > 0
    fake_xs = fake_xs > 0

    # Move axis 2 to the front so each item is indexed track-first.
    xs = [np.moveaxis(np.array(x), 2, 0) for x in xs]
    fake_xs = [np.moveaxis(np.array(fake_x), 2, 0) for fake_x in fake_xs]

    assert xs[0].ndim == 3, \
        'Pianoroll shape must have 3 dims, Got %d' % xs[0].ndim
    n_tracks, time_step, _ = xs[0].shape

    plt.ion()
    fig = plt.figure(figsize=(15, 8))

    # gridspec inside gridspec: 2x2 outer panels (one per sample), each split
    # into a 2 x n_tracks inner grid (real row on top, generated row below).
    outer_grid = gridspec.GridSpec(2, 2, wspace=0.1, hspace=0.2)

    for i in range(4):
        inner_grid = gridspec.GridSpecFromSubplotSpec(
            2, n_tracks, subplot_spec=outer_grid[i], wspace=0.0, hspace=0.0)

        x, fake_x = xs[i], fake_xs[i]

        # j walks the inner cells row-major: a = 1 (real) / 2 (generated),
        # b = track number 1..4.
        for j, (a, b) in enumerate(itertools.product([1, 2], [1, 2, 3, 4])):
            # add_subplot already attaches the axes to the figure.
            ax = fig.add_subplot(inner_grid[j])

            source = x if a == 1 else fake_x
            nz = np.nonzero(source[b - 1])

            if programs:
                instrument_class = pretty_midi.program_to_instrument_class(
                    programs[b - 1])
                ax.set_xlabel(f'Time({instrument_class})')

            # Only the first track column keeps a pitch axis.
            if b == 1:
                ax.set_ylabel('Pitch')
            else:
                ax.set_yticks([])

            ax.scatter(nz[0], nz[1], s=np.pi * 3, color='bgrcmk'[b - 1])
            ax.set_ylim(min_pitch, max_pitch)
            ax.set_xlim(0, time_step)

    if isinstance(iteration, int):
        plt.suptitle('iteration: {}'.format(iteration), fontsize=20)
        filename = os.path.join(save_dir,
                                'sample_iteration_%05d.png' % iteration)
    else:
        plt.suptitle('Inference', fontsize=20)
        filename = os.path.join(save_dir, 'sample_inference.png')
    plt.savefig(filename)
    plt.close(fig)