Example #1
def create_corpus_csvs(acme_dir, format_dict):
    """
    From a given acme dataset, create formatted csv files to use with
    our provided pytorch Dataset classes.

    Parameters
    ----------
    acme_dir : string
        The directory containing the acme data.

    format_dict: dict
        A dictionary (likely one provided in FORMATTERS), containing at least:
        name : string
            The name to print in the loading message.
        prefix : string
            The string to prepend to "_corpus_path" and "_corpus_line_nr" columns
            in the resulting metadata.csv file, as well as to use in the names
            of the resulting corpus-specific csv files like:
            {split}_{prefix}_corpus.csv
        df_to_str : function
            The function to convert from a pandas DataFrame to a string in the
            desired format.
    """
    name = format_dict["name"]
    prefix = format_dict["prefix"]
    df_converter_func = format_dict["df_to_str"]
    fh_dict = {
        split: open(os.path.join(acme_dir, f"{split}_{prefix}_corpus.csv"),
                    "w")
        for split in ["train", "valid", "test"]
    }
    line_counts = {split: 0 for split in ["train", "valid", "test"]}
    meta_df = pd.read_csv(os.path.join(acme_dir, "metadata.csv"))
    for idx, row in tqdm.tqdm(meta_df.iterrows(),
                              total=meta_df.shape[0],
                              desc=f"Creating {name} corpus"):
        alt_df = csv_to_df(os.path.join(acme_dir, row.altered_csv_path))
        alt_str = df_converter_func(alt_df)
        clean_df = csv_to_df(os.path.join(acme_dir, row.clean_csv_path))
        clean_str = df_converter_func(clean_df)
        deg_num = row.degradation_id
        split = row.split
        fh = fh_dict[split]
        fh.write(f"{alt_str},{clean_str},{deg_num}\n")
        meta_df.loc[idx, f"{prefix}_corpus_path"] = os.path.basename(fh.name)
        meta_df.loc[idx, f"{prefix}_corpus_line_nr"] = line_counts[split]
        line_counts[split] += 1
    meta_df.loc[:, f"{prefix}_corpus_line_nr"] = meta_df[
        f"{prefix}_corpus_line_nr"].astype(int)
    meta_df.to_csv(os.path.join(acme_dir, "metadata.csv"), index=False)
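A minimal usage sketch (not from the original project): the format_dict below is a hypothetical stand-in for an entry of FORMATTERS, and "./acme" stands in for a dataset directory containing metadata.csv and the per-excerpt note csvs.

def example_df_to_str(df):
    # Hypothetical converter: one "onset/pitch/dur" token per note
    return " ".join(f"{int(o)}/{int(p)}/{int(d)}"
                    for o, p, d in zip(df["onset"], df["pitch"], df["dur"]))

example_format_dict = {
    "name": "example",
    "prefix": "ex",
    "df_to_str": example_df_to_str,
}
create_corpus_csvs("./acme", example_format_dict)
# Writes train_ex_corpus.csv, valid_ex_corpus.csv and test_ex_corpus.csv into
# ./acme (one "altered,clean,degradation_id" line per excerpt) and adds
# ex_corpus_path and ex_corpus_line_nr columns to metadata.csv.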
def test_csv_to_df():
    csv_path = os.path.join(TEST_CACHE_PATH, "test.csv")
    fileio.df_to_csv(CLEAN_INPUT_DF, csv_path)

    # Check csv_to_df with every combination of single_track and non_overlapping
    for (track, overlap) in itertools.product([False, True], repeat=2):
        kwargs = {"single_track": track, "non_overlapping": overlap}

        correct = CLEAN_RES_DFS[track][overlap]
        res = fileio.csv_to_df(csv_path, **kwargs)

        assert res.equals(correct), f"csv_to_df result incorrect with args={kwargs}"

    # Check that velocity defaults to 100 if not present in the csv
    CLEAN_INPUT_DF[NOTE_DF_SORT_ORDER[:-1]].to_csv(csv_path, index=None, header=False)

    res = fileio.csv_to_df(csv_path)
    correct = CLEAN_RES_DFS[False][False]
    assert res[NOTE_DF_SORT_ORDER[:-1]].equals(correct[NOTE_DF_SORT_ORDER[:-1]])
    assert all(res["velocity"] == 100)
    assert res["velocity"].dtype == "int64"
def test_csv_to_midi():
    df = pd.DataFrame({
        "onset": 0,
        "track": [0, 0, 1],
        "pitch": [10, 20, 30],
        "dur": 1000
    })
    fileio.df_to_csv(df, "test.csv")

    # Test basic writing
    fileio.csv_to_midi("test.csv", "test.mid")
    assert fileio.midi_to_df("test.mid").equals(
        df), "Writing df to MIDI and reading changes df."

    # Test that writing overwrites existing notes in the excerpt
    df.pitch += 10
    fileio.df_to_csv(df, "test.csv")
    fileio.csv_to_midi("test.csv", "test2.mid", existing_midi_path="test.mid")
    assert fileio.midi_to_df("test2.mid").equals(
        df), "Writing df to MIDI with existing MIDI does not overwrite notes."

    # Test that existing notes before excerpt_start are kept
    fileio.csv_to_midi("test.csv",
                       "test2.mid",
                       existing_midi_path="test.mid",
                       excerpt_start=1000)
    expected = pd.DataFrame({
        "onset": [0, 0, 0, 1000, 1000, 1000],
        "track": [0, 0, 1, 0, 0, 1],
        "pitch": [10, 20, 30, 20, 30, 40],
        "dur": 1000,
    })
    assert fileio.midi_to_df("test2.mid").equals(
        expected), "Writing to MIDI doesn't copy notes before excerpt_start"

    # Test that existing notes after the excerpt end are kept
    fileio.csv_to_midi("test.csv",
                       "test.mid",
                       existing_midi_path="test2.mid",
                       excerpt_length=1000)
    expected = pd.DataFrame({
        "onset": [0, 0, 0, 1000, 1000, 1000],
        "track": [0, 0, 1, 0, 0, 1],
        "pitch": [20, 30, 40, 20, 30, 40],
        "dur": 1000,
    })
    assert fileio.midi_to_df("test.mid").equals(
        expected), "Writing to MIDI doesn't copy notes after excerpt_length"

    df.track = 2
    fileio.df_to_csv(df, "test.csv")
    fileio.csv_to_midi("test.csv",
                       "test.mid",
                       existing_midi_path="test2.mid",
                       excerpt_length=1000)
    expected = pd.DataFrame({
        "onset": [0, 0, 0, 1000, 1000, 1000],
        "track": [2, 2, 2, 0, 0, 1],
        "pitch": [20, 30, 40, 20, 30, 40],
        "dur": 1000,
    })
    assert fileio.midi_to_df("test.mid").equals(
        expected), "Writing to MIDI with extra track breaks"

    csv_path = "test.csv"
    midi_path = "test.mid"
    fileio.df_to_csv(CLEAN_INPUT_DF, csv_path)
    # Some less robust tests regarding single_track and non_overlapping
    # (Robust versions will be in the *_to_df and df_to_* functions)
    for (track, overlap) in itertools.product([False, True], repeat=2):
        kwargs = {"single_track": track, "non_overlapping": overlap}

        df = fileio.csv_to_df(csv_path, **kwargs)
        fileio.df_to_midi(df, midi_path)
        correct = fileio.midi_to_df(midi_path)

        fileio.csv_to_midi(csv_path, midi_path, **kwargs)
        res = fileio.midi_to_df(midi_path)

        assert res.equals(
            correct
        ), f"csv_to_midi not using args correctly with args={kwargs}"

    for filename in ["test.mid", "test2.mid", "test.csv"]:
        try:
            os.remove(filename)
        except Exception:
            pass
def test_midi_dir_to_csv():
    basenames = ["test", "test2", "alb_se2"]
    midi_dir = os.path.dirname(TEST_MID)
    midi_paths = [os.path.join(midi_dir, f"{name}.mid") for name in basenames]
    csv_dir = TEST_CACHE_PATH
    csv_paths = [os.path.join(csv_dir, f"{name}.csv") for name in basenames]

    for csv_path in csv_paths:
        try:
            os.remove(csv_path)
        except Exception:
            pass

    midi2_path = os.path.join(midi_dir, "test2.mid")
    shutil.copyfile(TEST_MID, midi2_path)

    fileio.midi_dir_to_csv(midi_dir, csv_dir)

    for csv_path in csv_paths:
        assert os.path.exists(csv_path), f"{csv_path} was not created."

    # This relies on pretty_midi being correct
    m = pretty_midi.PrettyMIDI(TEST_MID)
    midi_notes = []
    midi_notes2 = []
    for i, instrument in enumerate(m.instruments):
        for note in instrument.notes:
            midi_notes.append({
                "onset": int(round(note.start * 1000)),
                "track": i,
                "pitch": note.pitch,
                "dur": int(round(note.end * 1000) - round(note.start * 1000)),
            })
            midi_notes2.append(midi_notes[-1])

    # Check that notes were written correctly
    for csv_path, notes in zip(csv_paths, [midi_notes, midi_notes2]):
        with open(csv_path, "r") as file:
            for i, line in enumerate(file):
                split = line.split(",")
                note = {
                    "onset": int(split[0]),
                    "track": int(split[1]),
                    "pitch": int(split[2]),
                    "dur": int(split[3]),
                }
                assert note in notes, (f"csv note {note} not in list " +
                                       "of MIDI notes from pretty_midi " +
                                       "(or was duplicated).")
                notes.remove(note)

    # Test that all MIDI notes were in the df
    for notes in [midi_notes, midi_notes2]:
        assert len(notes) == 0, ("Some MIDI notes (from pretty_midi) were " +
                                 f"not found in the DataFrame: {notes}")

    # Some less robust tests regarding single_track and non_overlapping
    # (Robust versions will be in the *_to_df and df_to_* functions)
    for (track, overlap) in itertools.product([True, False], repeat=2):
        kwargs = {"single_track": track, "non_overlapping": overlap}
        fileio.midi_dir_to_csv(midi_dir, csv_dir, **kwargs)
        for midi_path, csv_path in zip(midi_paths, csv_paths):
            df = fileio.csv_to_df(csv_path, **kwargs)
            assert df.equals(fileio.midi_to_df(midi_path, **kwargs)), (
                "midi_dir_to_csv not using single_track and non_overlapping "
                "correctly.")

    os.remove(midi2_path)
def test_midi_to_csv():
    # This method is just a call to midi_to_df followed by df_to_csv
    csv_path = os.path.join(TEST_CACHE_PATH, "test.csv")

    try:
        os.remove(csv_path)
    except Exception:
        pass

    fileio.midi_to_csv(TEST_MID, csv_path)

    # This relies on pretty_midi being correct
    m = pretty_midi.PrettyMIDI(TEST_MID)
    midi_notes = []
    for i, instrument in enumerate(m.instruments):
        for note in instrument.notes:
            midi_notes.append({
                "onset": int(round(note.start * 1000)),
                "track": i,
                "pitch": note.pitch,
                "dur": int(round(note.end * 1000) - round(note.start * 1000)),
            })

    # Check that notes were written correctly
    with open(csv_path, "r") as file:
        for i, line in enumerate(file):
            split = line.split(",")
            note = {
                "onset": int(split[0]),
                "track": int(split[1]),
                "pitch": int(split[2]),
                "dur": int(split[3]),
            }
            assert note in midi_notes, (f"csv note {note} not in list " +
                                        "of MIDI notes from pretty_midi " +
                                        "(or was duplicated).")
            midi_notes.remove(note)

    # Test that all MIDI notes were in the df
    assert len(midi_notes) == 0, ("Some MIDI notes (from pretty_midi) were "
                                  f"not found in the DataFrame: {midi_notes}")

    # Some less robust tests regarding single_track and non_overlapping
    # (Robust versions will be in the *_to_df and df_to_* functions)
    midi_path = TEST_MID
    csv_path = "test.csv"
    for (track, overlap) in itertools.product([True, False], repeat=2):
        kwargs = {"single_track": track, "non_overlapping": overlap}
        fileio.midi_to_csv(midi_path, csv_path, **kwargs)
        df = fileio.csv_to_df(csv_path, **kwargs)
        assert df.equals(
            fileio.midi_to_df(midi_path, **kwargs)
        ), "midi_to_csv not using single_track and non_overlapping correctly."

    # Check writing without any directory
    fileio.midi_to_csv(TEST_MID, "test.csv")
    try:
        os.remove("test.csv")
    except Exception:
        pass
Example #7
def load_file(
    filename,
    pr_min_pitch=MIN_PITCH_DEFAULT,
    pr_max_pitch=MAX_PITCH_DEFAULT,
    pr_time_increment=40,
):
    """
    Load the given filename into a pandas dataframe.

    Parameters
    ----------
    filename : string
        The file to load into a dataframe.

    pr_min_pitch : int
        The minimum pitch for any piano roll, inclusive.

    pr_max_pitch : int
        The maximum pitch for any piano roll, inclusive.

    pr_time_increment : int
        The length of each frame of any piano roll.

    Returns
    -------
    df : pandas dataframe
        A pandas dataframe representing the music from the given file.
    """
    ext = os.path.splitext(os.path.basename(filename))[1]

    if ext == ".mid":
        return fileio.midi_to_df(filename)

    if ext == ".csv":
        return fileio.csv_to_df(filename)

    if ext == ".pkl":
        with open(filename, "rb") as file:
            pkl = pickle.load(file)

        piano_roll = pkl["piano_roll"]

        # Size of a single piano roll's pitch dimension
        pitch_range = pr_max_pitch - pr_min_pitch + 1

        if piano_roll.shape[1] == pitch_range:
            # Normal piano roll only -- no onsets
            note_pr = piano_roll.astype(int)
            # An onset is a frame where a note is on but was off in the
            # previous frame
            onset_pr = (np.roll(note_pr, 1, axis=0) - note_pr) == -1
            onset_pr[0] = note_pr[0]
            onset_pr = onset_pr.astype(int)

        elif piano_roll.shape[1] == 2 * pitch_range:
            # Piano roll with onsets
            note_pr = piano_roll[:, :piano_roll.shape[1] // 2].astype(int)
            onset_pr = piano_roll[:, piano_roll.shape[1] // 2:].astype(int)

        else:
            raise ValueError("Piano roll dimension 2 size ("
                             f"{piano_roll.shape[1]}) must be equal to 1 or 2"
                             f" times the given pitch range [{pr_min_pitch} - "
                             f"{pr_max_pitch}] = {pitch_range}")

        piano_roll = np.vstack((note_pr, onset_pr))
        return formatters.double_pianoroll_to_df(
            piano_roll,
            min_pitch=pr_min_pitch,
            max_pitch=pr_max_pitch,
            time_increment=pr_time_increment,
        )

    raise NotImplementedError(f"Extension {ext} not supported.")