def test_clean_df():
    """
    Check clean_df's default arguments, that it leaves its input untouched,
    and that each combination of flags yields the expected DataFrame.
    """
    # Omitting the optional arguments must be the same as passing False for
    # both: multi-track, overlaps kept (sorting, columns, and index still
    # corrected).
    implicit = clean_df(CLEAN_INPUT_DF)
    explicit = clean_df(CLEAN_INPUT_DF, single_track=False, non_overlapping=False)
    assert implicit.equals(
        explicit
    ), "clean_df does not default to False, False for optional args."

    flag_values = [True, False]
    for single_track, non_overlapping in itertools.product(flag_values, flag_values):
        kwargs = {"single_track": single_track, "non_overlapping": non_overlapping}

        # Snapshot the input so any in-place modification can be detected.
        snapshot = CLEAN_INPUT_DF.copy()
        cleaned = clean_df(CLEAN_INPUT_DF, **kwargs)
        assert CLEAN_INPUT_DF.equals(snapshot), "clean_df changed input df"

        # Expected results are indexed first by single_track, then by
        # non_overlapping.
        expected = CLEAN_RES_DFS[single_track][non_overlapping]
        assert cleaned.equals(
            expected
        ), f"clean_df result incorrect for args: {kwargs}"
def csv_to_df(csv_path, single_track=False, non_overlapping=False):
    """
    Load a csv of note events and return it as a cleaned `note_df`.

    The csv is read with its column names supplied from NOTE_DF_SORT_ORDER
    (so the file is treated as having no header row), and the result is
    passed through clean_df together with the two optional flags.

    Parameters
    ----------
    csv_path : str
        The path of the csv to be imported.

    single_track : boolean
        True to set the track of every note to 0. This will happen before
        overlaps are removed.

    non_overlapping : boolean
        True to remove overlaps from the resulting dataframe by passing the df
        to df_utils.remove_pitch_overlaps. This will create a situation where,
        for every (track, pitch) pair, for any point in time which there is a
        sustained note present in the input, there will be a sustained note
        in the returned df. Likewise for any point with a note onset.

    Returns
    -------
    note_df : pd.DataFrame
        A note_df, in mdtk's standard format. With columns:
            onset (int): onset time of a note, in ms.
            track (int): the track of the note.
            pitch (int): the MIDI pitch of the note.
            dur (int): the duration of the note, in ms.
        Sorting will be first by onset, then track, then pitch, then duration.
    """
    return clean_df(
        pd.read_csv(csv_path, names=NOTE_DF_SORT_ORDER),
        single_track=single_track,
        non_overlapping=non_overlapping,
    )
def test_midi_to_df():
    """
    Check fileio.midi_to_df against notes read directly with pretty_midi.

    First verifies that the default call returns exactly the same multiset of
    notes as pretty_midi reports, then verifies that each combination of the
    single_track / non_overlapping flags matches clean_df applied to the
    reference notes with the same flags.
    """
    df = fileio.midi_to_df(TEST_MID)

    # Build the reference notes. This relies on pretty_midi being correct.
    m = pretty_midi.PrettyMIDI(TEST_MID)
    midi_notes = [
        {
            "onset": int(round(note.start * 1000)),
            "track": i,
            "pitch": note.pitch,
            "dur": int(round(note.end * 1000) - round(note.start * 1000)),
            "velocity": note.velocity,
        }
        for i, instrument in enumerate(m.instruments)
        for note in instrument.notes
    ]
    # Built before any notes are consumed below, so it holds the full set.
    midi_df = pd.DataFrame(midi_notes)

    # Test that all df notes are in the MIDI. Each matched note is removed so
    # that a note duplicated in the df cannot match the same MIDI note twice.
    for df_note in df.to_dict("records"):
        assert df_note in midi_notes, (
            f"DataFrame note {df_note} not in "
            + "list of MIDI notes from pretty_midi "
            + "(or was duplicated)."
        )
        midi_notes.remove(df_note)

    # Test that all MIDI notes were in the df: nothing may be left unmatched.
    assert not midi_notes, (
        "Some MIDI notes (from pretty_midi) were "
        + f"not found in the DataFrame: {midi_notes}"
    )

    # Check clean_df (args, sorting)
    for track, overlap in itertools.product([False, True], repeat=2):
        kwargs = {"single_track": track, "non_overlapping": overlap}

        correct = clean_df(midi_df, **kwargs)
        res = fileio.midi_to_df(TEST_MID, **kwargs)

        # The message names the function actually under test (midi_to_df).
        assert res.equals(
            correct
        ), f"midi_to_df not using args correctly with args={kwargs}"
def midi_to_df(midi_path, single_track=False, non_overlapping=False):
    """
    Parse a MIDI file with pretty_midi and return its notes as a DataFrame.

    Parameters
    ----------
    midi_path : string
        The filename of the MIDI file to parse.

    single_track : boolean
        True to set the track of every note to 0. This will happen before
        overlaps are removed.

    non_overlapping : boolean
        True to remove overlaps from the resulting dataframe by passing the df
        to df_utils.remove_pitch_overlaps. This will create a situation where,
        for every (track, pitch) pair, for any point in time which there is a
        sustained note present in the input, there will be a sustained note
        in the returned df. Likewise for any point with a note onset.

    Returns
    -------
    df : DataFrame or None
        A pandas DataFrame containing the notes parsed from the given MIDI
        file, after being passed through clean_df. There will be 4 columns:
            onset: Onset time of the note, in milliseconds.
            track: The track number of the instrument the note is from
                (its position in the parsed file's instrument list).
            pitch: The MIDI pitch number for the note.
            dur: The duration of the note (offset - onset), in milliseconds.
        Sorting will be first by onset, then track, then pitch, then duration.

        None is returned (and a warning is logged) if the file cannot be
        parsed or if it contains no notes.
    """
    # Any failure while parsing is treated as "skip this file".
    try:
        parsed = pretty_midi.PrettyMIDI(midi_path)
    except Exception:
        logging.warning(f"Error parsing midi file {midi_path}. Skipping.")
        return None

    # One row per note. Onset and offset are each rounded to whole
    # milliseconds before subtracting, so onset + dur equals the rounded
    # offset.
    rows = [
        {
            "onset": int(round(note.start * 1000)),
            "track": track_number,
            "pitch": note.pitch,
            "dur": int(round(note.end * 1000) - round(note.start * 1000)),
        }
        for track_number, instrument in enumerate(parsed.instruments)
        for note in instrument.notes
    ]

    if not rows:
        logging.warning(
            f"WARNING: the midi file located at {midi_path} is empty. "
            "Returning None.",
        )
        return None

    return clean_df(
        pd.DataFrame(rows), single_track=single_track, non_overlapping=non_overlapping
    )