Пример #1
0
def load_file(
    filename,
    pr_min_pitch=MIN_PITCH_DEFAULT,
    pr_max_pitch=MAX_PITCH_DEFAULT,
    pr_time_increment=40,
):
    """
    Load the given filename into a pandas dataframe.

    The loader is selected from the file extension: ``.mid`` and ``.csv``
    files are handed to ``fileio``, while ``.pkl`` files must unpickle to an
    object whose ``"piano_roll"`` entry is a 2-d array. That array's second
    dimension must be either one pitch range wide (notes only; onsets are
    derived from the note frames) or two pitch ranges wide (notes followed
    by onsets).

    Parameters
    ----------
    filename : string
        The file to load into a dataframe.

    pr_min_pitch : int
        The minimum pitch for any piano roll, inclusive.

    pr_max_pitch : int
        The maximum pitch for any piano roll, inclusive.

    pr_time_increment : int
        The length of each frame of any piano roll.

    Returns
    -------
    df : pandas dataframe
        A pandas dataframe representing the music from the given file.

    Raises
    ------
    ValueError
        If a pickled piano roll's second dimension is neither 1 nor 2 times
        the size of the pitch range.

    NotImplementedError
        If the file extension is not one of ``.mid``, ``.csv`` or ``.pkl``.
    """
    ext = os.path.splitext(os.path.basename(filename))[1]

    if ext == ".mid":
        return fileio.midi_to_df(filename)

    if ext == ".csv":
        return fileio.csv_to_df(filename)

    if ext == ".pkl":
        # SECURITY: unpickling can execute arbitrary code. Only load pickle
        # files that come from a trusted source.
        with open(filename, "rb") as file:
            pkl = pickle.load(file)

        piano_roll = pkl["piano_roll"]

        # Both bounds are inclusive, hence the +1.
        num_pitches = pr_max_pitch - pr_min_pitch + 1
        width = piano_roll.shape[1]

        if width == num_pitches:
            # Normal piano roll only -- no onsets
            note_pr = piano_roll.astype(int)
            # A frame holds an onset where the previous frame was 0 and the
            # current frame is 1 (previous - current == -1).
            onset_pr = (np.roll(note_pr, 1, axis=0) - note_pr) == -1
            # np.roll wraps the last frame around to the front, so the first
            # frame is set explicitly: every note active there is an onset.
            onset_pr[0] = note_pr[0]
            onset_pr = onset_pr.astype(int)

        elif width == 2 * num_pitches:
            # Piano roll with onsets: first half notes, second half onsets.
            # Slice with the integer pitch count (float indices are invalid).
            note_pr = piano_roll[:, :num_pitches].astype(int)
            onset_pr = piano_roll[:, num_pitches:].astype(int)

        else:
            raise ValueError("Piano roll dimension 2 size ("
                             f"{width}) must be equal to 1 or 2"
                             f" times the given pitch range [{pr_min_pitch} - "
                             f"{pr_max_pitch}] = "
                             f"{num_pitches}")

        # NOTE(review): vstack joins the two rolls along axis 0 -- confirm
        # this is the layout formatters.double_pianoroll_to_df expects.
        piano_roll = np.vstack((note_pr, onset_pr))
        return formatters.double_pianoroll_to_df(
            piano_roll,
            min_pitch=pr_min_pitch,
            max_pitch=pr_max_pitch,
            time_increment=pr_time_increment,
        )

    raise NotImplementedError(f"Extension {ext} not supported.")
def test_csv_to_midi():
    """
    Check fileio.csv_to_midi: plain writing, overwriting the notes of an
    existing MIDI file, excerpt_start / excerpt_length handling, writing to
    a track the existing MIDI lacks, and forwarding of the single_track and
    non_overlapping keyword arguments.

    The test writes "test.csv", "test.mid" and "test2.mid" into the current
    working directory and removes them afterwards, even when an assertion
    fails.
    """
    temp_files = ["test.mid", "test2.mid", "test.csv"]

    try:
        df = pd.DataFrame(
            {
                "onset": 0,
                "track": [0, 0, 1],
                "pitch": [10, 20, 30],
                "dur": 1000,
                "velocity": 50,
            }
        )
        fileio.df_to_csv(df, "test.csv")

        # Test basic writing
        fileio.csv_to_midi("test.csv", "test.mid")
        assert fileio.midi_to_df("test.mid").equals(
            df
        ), "Writing df to MIDI and reading changes df."

        # Test that writing should overwrite existing notes
        df.pitch += 10
        fileio.df_to_csv(df, "test.csv")
        fileio.csv_to_midi("test.csv", "test2.mid", existing_midi_path="test.mid")
        assert fileio.midi_to_df("test2.mid").equals(
            df
        ), "Writing df to MIDI with existing MIDI does not overwrite notes."

        # Test that writing skips non-overwritten notes
        fileio.csv_to_midi(
            "test.csv", "test2.mid", existing_midi_path="test.mid", excerpt_start=1000
        )
        expected = pd.DataFrame(
            {
                "onset": [0, 0, 0, 1000, 1000, 1000],
                "track": [0, 0, 1, 0, 0, 1],
                "pitch": [10, 20, 30, 20, 30, 40],
                "dur": 1000,
                "velocity": 50,
            }
        )
        assert fileio.midi_to_df("test2.mid").equals(
            expected
        ), "Writing to MIDI doesn't copy notes before excerpt_start"

        # Test that writing skips non-overwritten notes past end
        fileio.csv_to_midi(
            "test.csv", "test.mid", existing_midi_path="test2.mid", excerpt_length=1000
        )
        expected = pd.DataFrame(
            {
                "onset": [0, 0, 0, 1000, 1000, 1000],
                "track": [0, 0, 1, 0, 0, 1],
                "pitch": [20, 30, 40, 20, 30, 40],
                "dur": 1000,
                "velocity": 50,
            }
        )
        assert fileio.midi_to_df("test.mid").equals(
            expected
        ), "Writing to MIDI doesn't copy notes after excerpt_length"

        # Test writing notes on a track the existing MIDI does not have
        df.track = 2
        fileio.df_to_csv(df, "test.csv")
        fileio.csv_to_midi(
            "test.csv", "test.mid", existing_midi_path="test2.mid", excerpt_length=1000
        )
        expected = pd.DataFrame(
            {
                "onset": [0, 0, 0, 1000, 1000, 1000],
                "track": [2, 2, 2, 0, 0, 1],
                "pitch": [20, 30, 40, 20, 30, 40],
                "dur": 1000,
                "velocity": 50,
            }
        )
        assert fileio.midi_to_df("test.mid").equals(
            expected
        ), "Writing to MIDI with extra track breaks"

        csv_path = "test.csv"
        midi_path = "test.mid"
        fileio.df_to_csv(CLEAN_INPUT_DF, csv_path)
        # Some less robust tests regarding single_track and non_overlapping
        # (Robust versions will be in the *_to_df and df_to_* functions)
        for (track, overlap) in itertools.product([False, True], repeat=2):
            kwargs = {"single_track": track, "non_overlapping": overlap}

            # Reference result: apply the kwargs when reading the csv, then
            # write and re-read the MIDI without them.
            df = fileio.csv_to_df(csv_path, **kwargs)
            fileio.df_to_midi(df, midi_path)
            correct = fileio.midi_to_df(midi_path)

            fileio.csv_to_midi(csv_path, midi_path, **kwargs)
            res = fileio.midi_to_df(midi_path)

            assert res.equals(
                correct
            ), f"csv_to_midi not using args correctly with args={kwargs}"
    finally:
        # Best-effort cleanup; runs even if an assertion above failed.
        for filename in temp_files:
            try:
                os.remove(filename)
            except OSError:
                pass
def test_df_to_midi():
    """
    Check fileio.df_to_midi: plain writing, overwriting the notes of an
    existing MIDI file, excerpt_start / excerpt_length handling, writing to
    a track the existing MIDI lacks, and that non-note data of the existing
    MIDI (instrument name/program/is_drum, pitch bends, control changes,
    lyrics, key and time signatures) compares equal after writing.

    The test writes "test.mid" and "test2.mid" into the current working
    directory and removes them afterwards, even when an assertion fails.
    """
    temp_files = ["test.mid", "test2.mid"]

    try:
        df = pd.DataFrame(
            {
                "onset": 0,
                "track": [0, 0, 1],
                "pitch": [10, 20, 30],
                "dur": 1000,
                "velocity": 50,
            }
        )

        # Test basic writing
        fileio.df_to_midi(df, "test.mid")
        assert fileio.midi_to_df("test.mid").equals(
            df
        ), "Writing df to MIDI and reading changes df."

        # Test that writing should overwrite existing notes
        df.pitch += 10
        fileio.df_to_midi(df, "test2.mid", existing_midi_path="test.mid")
        assert fileio.midi_to_df("test2.mid").equals(
            df
        ), "Writing df to MIDI with existing MIDI does not overwrite notes."

        # Test that writing skips non-overwritten notes
        fileio.df_to_midi(
            df, "test2.mid", existing_midi_path="test.mid", excerpt_start=1000
        )
        expected = pd.DataFrame(
            {
                "onset": [0, 0, 0, 1000, 1000, 1000],
                "track": [0, 0, 1, 0, 0, 1],
                "pitch": [10, 20, 30, 20, 30, 40],
                "dur": 1000,
                "velocity": 50,
            }
        )
        assert fileio.midi_to_df("test2.mid").equals(
            expected
        ), "Writing to MIDI doesn't copy notes before excerpt_start"

        # Test that writing skips non-overwritten notes past end
        fileio.df_to_midi(
            df, "test.mid", existing_midi_path="test2.mid", excerpt_length=1000
        )
        expected = pd.DataFrame(
            {
                "onset": [0, 0, 0, 1000, 1000, 1000],
                "track": [0, 0, 1, 0, 0, 1],
                "pitch": [20, 30, 40, 20, 30, 40],
                "dur": 1000,
                "velocity": 50,
            }
        )
        assert fileio.midi_to_df("test.mid").equals(
            expected
        ), "Writing to MIDI doesn't copy notes after excerpt_length"

        # Test writing notes on a track the existing MIDI does not have
        df.track = 2
        fileio.df_to_midi(
            df, "test.mid", existing_midi_path="test2.mid", excerpt_length=1000
        )
        expected = pd.DataFrame(
            {
                "onset": [0, 0, 0, 1000, 1000, 1000],
                "track": [2, 2, 2, 0, 0, 1],
                "pitch": [20, 30, 40, 20, 30, 40],
                "dur": 1000,
                "velocity": 50,
            }
        )
        assert fileio.midi_to_df("test.mid").equals(
            expected
        ), "Writing to MIDI with extra track breaks"

        # Check all non-note events: add one of each kind to the existing
        # MIDI file and rewrite it.
        midi_obj = pretty_midi.PrettyMIDI("test.mid")
        midi_obj.instruments[0].name = "test"
        midi_obj.instruments[0].program = 100
        midi_obj.instruments[0].is_drum = True
        midi_obj.instruments[0].pitch_bends.append(pretty_midi.PitchBend(10, 0))
        midi_obj.instruments[0].control_changes.append(pretty_midi.ControlChange(10, 10, 0))
        midi_obj.lyrics.append(pretty_midi.Lyric("test", 0))
        midi_obj.time_signature_changes.append(pretty_midi.TimeSignature(2, 4, 1))
        midi_obj.key_signature_changes.append(pretty_midi.KeySignature(5, 1))
        midi_obj.write("test.mid")

        fileio.df_to_midi(expected, "test2.mid", existing_midi_path="test.mid")
        assert fileio.midi_to_df("test2.mid").equals(expected)

        # Check non-note events and data here
        # NOTE(review): zip() stops at the shorter sequence, so events that
        # are missing entirely from the new file would go unnoticed by the
        # loops below -- consider asserting equal lengths once confirmed.
        new_midi = pretty_midi.PrettyMIDI("test2.mid")

        for instrument, new_instrument in zip(
            midi_obj.instruments, new_midi.instruments
        ):
            assert instrument.name == new_instrument.name
            assert instrument.program == new_instrument.program
            assert instrument.is_drum == new_instrument.is_drum
            for pb, new_pb in zip(instrument.pitch_bends, new_instrument.pitch_bends):
                assert pb.pitch == new_pb.pitch
                assert pb.time == new_pb.time
            for cc, new_cc in zip(
                instrument.control_changes, new_instrument.control_changes
            ):
                assert cc.number == new_cc.number
                assert cc.value == new_cc.value
                assert cc.time == new_cc.time

        for ks, new_ks in zip(
            midi_obj.key_signature_changes, new_midi.key_signature_changes
        ):
            assert ks.key_number == new_ks.key_number
            assert ks.time == new_ks.time

        for lyric, new_lyric in zip(midi_obj.lyrics, new_midi.lyrics):
            assert lyric.text == new_lyric.text
            assert lyric.time == new_lyric.time

        for ts, new_ts in zip(
            midi_obj.time_signature_changes, new_midi.time_signature_changes
        ):
            assert ts.numerator == new_ts.numerator
            assert ts.denominator == new_ts.denominator
            assert ts.time == new_ts.time
    finally:
        # Best-effort cleanup; runs even if an assertion above failed.
        for filename in temp_files:
            try:
                os.remove(filename)
            except OSError:
                pass
def test_midi_dir_to_csv():
    """
    Check fileio.midi_dir_to_csv: one csv per expected MIDI basename is
    created in TEST_CACHE_PATH, the csv notes for TEST_MID and its copy
    match the notes pretty_midi reads from TEST_MID, and the single_track
    and non_overlapping keyword arguments are forwarded.

    A temporary copy of TEST_MID named "test2.mid" is placed next to it for
    the duration of the test and removed afterwards, even when an assertion
    fails.
    """
    basenames = ["test", "test2", "alb_se2"]
    midi_dir = os.path.dirname(TEST_MID)
    midi_paths = [os.path.join(midi_dir, f"{name}.mid") for name in basenames]
    csv_dir = TEST_CACHE_PATH
    csv_paths = [os.path.join(csv_dir, f"{name}.csv") for name in basenames]

    # Start from a clean slate so the existence check below is meaningful.
    for csv_path in csv_paths:
        try:
            os.remove(csv_path)
        except OSError:
            pass

    # Built with os.path.join so it matches the "test2" entry of midi_paths.
    midi2_path = os.path.join(midi_dir, "test2.mid")
    shutil.copyfile(TEST_MID, midi2_path)

    try:
        fileio.midi_dir_to_csv(midi_dir, csv_dir)

        for csv_path in csv_paths:
            assert os.path.exists(csv_path), f"{csv_path} was not created."

        # This relies on pretty_midi being correct
        m = pretty_midi.PrettyMIDI(TEST_MID)
        midi_notes = []
        for track, instrument in enumerate(m.instruments):
            for note in instrument.notes:
                midi_notes.append(
                    {
                        "onset": int(round(note.start * 1000)),
                        "track": track,
                        "pitch": note.pitch,
                        "dur": int(
                            round(note.end * 1000) - round(note.start * 1000)
                        ),
                        "velocity": note.velocity,
                    }
                )
        # test2.mid is a copy of TEST_MID, so it should hold the same notes.
        midi_notes2 = list(midi_notes)

        # Check that notes were written correctly. zip() pairs only the
        # first two csv files with a note list; the third is not checked here.
        for csv_path, notes in zip(csv_paths, [midi_notes, midi_notes2]):
            with open(csv_path, "r") as file:
                for line in file:
                    split = line.split(",")
                    note = {
                        "onset": int(split[0]),
                        "track": int(split[1]),
                        "pitch": int(split[2]),
                        "dur": int(split[3]),
                        "velocity": int(split[4]),
                    }
                    assert note in notes, (
                        f"csv note {note} not in list "
                        + "of MIDI notes from pretty_midi "
                        + "(or was duplicated)."
                    )
                    # Remove each matched note so duplicates are detected.
                    notes.remove(note)

        # Test that all MIDI notes were in the df
        for notes in [midi_notes, midi_notes2]:
            assert len(notes) == 0, (
                "Some MIDI notes (from pretty_midi) were "
                + f"not found in the DataFrame: {notes}"
            )

        # Some less robust tests regarding single_track and non_overlapping
        # (Robust versions will be in the *_to_df and df_to_* functions)
        for (track, overlap) in itertools.product([True, False], repeat=2):
            kwargs = {"single_track": track, "non_overlapping": overlap}
            fileio.midi_dir_to_csv(midi_dir, csv_dir, **kwargs)
            for midi_path, csv_path in zip(midi_paths, csv_paths):
                df = fileio.csv_to_df(csv_path, **kwargs)
                assert df.equals(fileio.midi_to_df(midi_path, **kwargs)), (
                    "midi_dir_to_csv not using single_track and non_overlapping "
                    "correctly."
                )
    finally:
        # Always remove the temporary copy so a failing assertion does not
        # leave an extra MIDI file in the test data directory.
        os.remove(midi2_path)
def test_midi_rounding():
    """Smoke test: fileio.midi_to_df must load ALB_MID without raising."""
    # The resulting dataframe is not inspected; only successful loading
    # is exercised.
    fileio.midi_to_df(ALB_MID)
def test_midi_to_csv():
    """
    Check fileio.midi_to_csv: the csv written for TEST_MID holds exactly the
    notes pretty_midi reads from TEST_MID, the single_track and
    non_overlapping keyword arguments are forwarded, and writing to a bare
    filename (no directory component) works.

    A "test.csv" file is written into the current working directory and
    removed afterwards, even when an assertion fails.
    """
    # This method is just calls to midi_to_df and df_to_csv
    cache_csv_path = TEST_CACHE_PATH + os.path.sep + "test.csv"
    local_csv_path = "test.csv"

    # Make sure a stale csv from an earlier run cannot satisfy the checks.
    try:
        os.remove(cache_csv_path)
    except OSError:
        pass

    try:
        fileio.midi_to_csv(TEST_MID, cache_csv_path)

        # This relies on pretty_midi being correct
        m = pretty_midi.PrettyMIDI(TEST_MID)
        midi_notes = []
        for track, instrument in enumerate(m.instruments):
            for note in instrument.notes:
                midi_notes.append(
                    {
                        "onset": int(round(note.start * 1000)),
                        "track": track,
                        "pitch": note.pitch,
                        "dur": int(
                            round(note.end * 1000) - round(note.start * 1000)
                        ),
                        "velocity": note.velocity,
                    }
                )

        # Check that notes were written correctly
        with open(cache_csv_path, "r") as file:
            for line in file:
                split = line.split(",")
                note = {
                    "onset": int(split[0]),
                    "track": int(split[1]),
                    "pitch": int(split[2]),
                    "dur": int(split[3]),
                    "velocity": int(split[4]),
                }
                assert note in midi_notes, (
                    f"csv note {note} not in list "
                    + "of MIDI notes from pretty_midi "
                    + "(or was duplicated)."
                )
                # Remove each matched note so duplicates are detected.
                midi_notes.remove(note)

        # Test that all MIDI notes were in the df
        assert len(midi_notes) == 0, (
            "Some MIDI notes (from pretty_midi) were "
            f"not found in the DataFrame: {midi_notes}"
        )

        # Some less robust tests regarding single_track and non_overlapping
        # (Robust versions will be in the *_to_df and df_to_* functions)
        midi_path = TEST_MID
        for (track, overlap) in itertools.product([True, False], repeat=2):
            kwargs = {"single_track": track, "non_overlapping": overlap}
            fileio.midi_to_csv(midi_path, local_csv_path, **kwargs)
            df = fileio.csv_to_df(local_csv_path, **kwargs)
            assert df.equals(
                fileio.midi_to_df(midi_path, **kwargs)
            ), "midi_to_csv not using single_track and non_overlapping correctly."

        # Check writing without any directory
        fileio.midi_to_csv(TEST_MID, local_csv_path)
    finally:
        # Best-effort cleanup; runs even if an assertion above failed.
        try:
            os.remove(local_csv_path)
        except OSError:
            pass