예제 #1
0
def test_tonas_load_label():
    """Check ``TonasStructure.load_label`` against the TONAS fixture file.

    The parsed labels are compared with four hand-written reference
    ``Label`` objects through ``_assert_all_equal``.
    """
    # Positional ``Label`` arguments for each reference entry.
    reference_args = (
        (0.233333, 0.366666, 53),
        (0.4, 0.583333, 58),
        (0.6, 0.75, 58),
        (0.766666, 0.95, 58),
    )
    reference = [Label(*args) for args in reference_args]
    parsed = dset.TonasStructure.load_label(
        "./tests/resource/gt_files/tonas_gt_file.notes.Corrected")
    _assert_all_equal(reference, parsed)
예제 #2
0
def test_cmedia_load_label():
    """Check ``CMediaStructure.load_label`` against the CMedia CSV fixture.

    The parsed labels are compared with five hand-written reference
    ``Label`` objects through ``_assert_all_equal``.
    """
    # Positional ``Label`` arguments for each reference entry.
    reference_args = (
        (0.12345, 0.3333, 50),
        (1.112, 1.5, 66),
        (1.6666, 1.78, 70),
        (1.6666, 1.8, 73),
        (1.94333, 2.3, 65),
    )
    reference = [Label(*args) for args in reference_args]
    parsed = dset.CMediaStructure.load_label(
        "./tests/resource/gt_files/cmedia_gt_file.csv")
    _assert_all_equal(reference, parsed)
예제 #3
0
def test_maps_load_label():
    """Check ``MapsStructure.load_label`` against the MAPS text fixture.

    The parsed labels are compared with five hand-written reference
    ``Label`` objects through ``_assert_all_equal``.
    """
    # Positional ``Label`` arguments for each reference entry.
    reference_args = (
        (0.5, 1.74117, 69),
        (0.843864, 1.13805, 66),
        (0.843864, 1.13805, 50),
        (0.990956, 1.74117, 73),
        (1.13805, 1.74117, 74),
    )
    reference = [Label(*args) for args in reference_args]
    parsed = dset.MapsStructure.load_label(
        "./tests/resource/gt_files/maps_gt_file.txt")
    _assert_all_equal(reference, parsed)
예제 #4
0
def test_maestro_load_label():
    """Check ``MaestroStructure.load_label`` against the MAESTRO MIDI fixture.

    The parsed labels are compared with five hand-written reference
    ``Label`` objects through ``_assert_all_equal``.
    """
    # Positional ``Label`` arguments for each reference entry; every entry
    # shares the same last two values (0 and 80).
    reference_args = (
        (31.979545, 32.340909, 60, 0, 80),
        (32.4, 32.5, 65, 0, 80),
        (32.520455, 33.079545, 67, 0, 80),
        (33.059091, 33.479545, 65, 0, 80),
        (33.579545, 33.940909, 63, 0, 80),
    )
    reference = [Label(*args) for args in reference_args]
    parsed = dset.MaestroStructure.load_label(
        "./tests/resource/gt_files/maestro_gt_file.mid")
    _assert_all_equal(reference, parsed)
예제 #5
0
def test_medleydb_load_label():
    """Check ``MedleyDBStructure.load_label`` against the MedleyDB CSV fixture.

    The parsed labels are compared with five hand-written reference
    ``Label`` objects through ``_assert_all_equal``.
    """
    # The former local ``t_unit = 256 / 44100`` was never read and has been
    # dropped; the reference values below are written out literally.
    expected = [
        Label(2.461315, 2.46712, 74.312369),
        Label(2.46712, 2.472925, 74.283337),
        Label(2.472925, 2.47873, 74.256145),
        Label(2.47873, 2.484535, 74.235337),
        Label(2.484535, 2.4903, 74.209788),
    ]
    gt_file_path = "./tests/resource/gt_files/medleydb_gt_file.csv"
    labels = dset.MedleyDBStructure.load_label(gt_file_path)
    _assert_all_equal(expected, labels)
예제 #6
0
    def load_label(cls, label_path):
        """Convert a frame-level pitch file into note-level ``Label`` objects.

        Every line of the file is parsed as one number and rounded to an
        integer. Consecutive frames sharing the same rounded value form a
        segment, and each segment becomes one label, using 10ms per frame
        plus a 20ms offset for the first frame.

        Segments whose value is 0, or that last less than 50ms, are skipped.

        NOTE(review): a segment is only emitted when another value change
        follows it, so the trailing segment of the file never produces a
        label -- confirm this is intended.

        Parameters
        ----------
        label_path
            Path to the text file holding one numeric value per line.

        Returns
        -------
        list
            The ``Label`` instances, in order of appearance.
        """
        with open(label_path, "r") as label_file:
            raw_values = label_file.readlines()

        pitches = np.array([round(float(value)) for value in raw_values])

        # Frame indices at which the rounded value differs from the previous frame.
        boundaries = np.where(pitches[1:] - pitches[:-1] != 0)[0] + 1
        # Frame 0 always opens the first segment.
        boundaries = np.insert(boundaries, 0, 0)

        labels = []
        for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:]):
            pitch = pitches[seg_start]
            if pitch == 0:
                continue

            onset = 0.01 * seg_start + 0.02  # The first frame starts from 20ms.
            offset = 0.01 * seg_end + 0.02
            if offset - onset < 0.05:
                # Segments below the 50ms minimum duration are dropped.
                continue

            labels.append(
                Label(start_time=float(onset), end_time=float(offset), note=pitch))
        return labels
예제 #7
0
    def load_label(cls, label_path):
        """Parse a comma-separated label file into ``Label`` objects.

        Each row must provide the columns ``start_time``, ``end_time``,
        ``instrument``, ``note``, ``start_beat``, ``end_beat`` and
        ``note_value``. The two time columns are divided by 44100 (presumably
        they are sample indices at a 44.1kHz sampling rate), and the
        instrument number is shifted down by one.

        Parameters
        ----------
        label_path
            Path to the CSV label file.

        Returns
        -------
        list
            One ``Label`` per row, in file order.
        """
        sample_rate = 44100
        parsed = []
        with open(label_path, "r") as label_file:
            for entry in csv.DictReader(label_file, delimiter=","):
                start_sec = float(entry["start_time"]) / sample_rate
                end_sec = float(entry["end_time"]) / sample_rate
                program = int(entry["instrument"]) - 1
                pitch = int(entry["note"])

                # The paper calls this a 'measure', which is rather ambiguous.
                beat_on = float(entry["start_beat"])

                # The 'end_beat' column actually stores a length in beats, so
                # the start position is added to get the real ending beat.
                beat_off = float(entry["end_beat"]) + beat_on
                value_name = entry["note_value"]

                parsed.append(
                    Label(start_time=start_sec,
                          end_time=end_sec,
                          note=pitch,
                          instrument=program,
                          start_beat=beat_on,
                          end_beat=beat_off,
                          note_value=value_name))
        return parsed
예제 #8
0
 def load_label(cls, label_path):
     """Read a comma-separated file with ``onset``, ``offset`` and ``note`` columns.

     ``onset`` and ``offset`` are parsed as floats and ``note`` as an int;
     each row yields one ``Label``.

     Parameters
     ----------
     label_path
         Path to the CSV label file.

     Returns
     -------
     list
         One ``Label`` per row, in file order.
     """
     with open(label_path, "r") as label_file:
         # The comprehension is fully evaluated before the file is closed.
         return [
             Label(start_time=float(entry["onset"]),
                   end_time=float(entry["offset"]),
                   note=int(entry["note"]))
             for entry in csv.DictReader(label_file, delimiter=",")
         ]
예제 #9
0
 def load_label(cls, label_path):
     """Parse a tab-separated label file into ``Label`` objects.

     The first line is treated as a header and discarded, and blank lines
     are skipped. Of the remaining lines, the first two tab-separated
     fields are parsed as floats (onset, offset) and the third as an int
     (note).

     Parameters
     ----------
     label_path
         Path to the tab-separated label file.

     Returns
     -------
     list
         One ``Label`` per non-blank data line, in file order.
     """
     # A context manager guarantees the file handle is closed; the previous
     # bare ``open(...).readlines()`` left closing to the garbage collector.
     with open(label_path, "r") as label_file:
         lines = label_file.readlines()[1:]  # Discard the first line which contains column names

     labels = []
     for line in lines:
         if line.strip() == "":
             continue
         values = line.split("\t")
         onset = float(values[0])
         offset = float(values[1])
         note = int(values[2].strip())
         labels.append(Label(start_time=onset, end_time=offset, note=note))
     return labels
예제 #10
0
    def load_label(cls, label_path):
        """Parse a label file whose rows hold onset, duration and note values.

        The first line is skipped. Every following line is split on ``", "``
        and must produce exactly four fields; the fourth one is ignored. The
        end time is computed as onset plus duration, and the note value is
        rounded to the nearest integer.

        Parameters
        ----------
        label_path
            Path to the label file.

        Returns
        -------
        list
            One ``Label`` per data line, in file order.
        """
        with open(label_path, "r") as label_file:
            records = label_file.readlines()[1:]

        def _to_label(record):
            # Convert one text row into a ``Label``.
            onset_txt, duration_txt, pitch_txt, _ = record.split(", ")
            return Label(start_time=float(onset_txt),
                         end_time=float(onset_txt) + float(duration_txt),
                         note=round(float(pitch_txt)))

        return [_to_label(record) for record in records]
예제 #11
0
def extract_feature_from_midi(midi_path, t_unit=0.01):
    """Extract feature for beat module from MIDI file.

    Every note of every instrument in the MIDI file is wrapped in a
    ``Label`` (start time, end time and pitch), and the resulting list is
    handed to ``extract_feature`` together with ``t_unit``.

    See Also
    --------
    omnizart.beat.features.extract_feature:
        The main feature extraction function of beat module.
    """
    instruments = pretty_midi.PrettyMIDI(midi_path).instruments
    note_labels = [
        Label(start_time=midi_note.start, end_time=midi_note.end, note=midi_note.pitch)
        for track in instruments
        for midi_note in track.notes
    ]
    return extract_feature(note_labels, t_unit=t_unit)
예제 #12
0
    def load_label(cls, label_path):
        """Turn a frame-level pitch file into one ``Label`` per voiced frame.

        Each line is parsed as a float. Frames whose value is below 0.1 are
        treated as having no pitch and skipped. Frame ``i`` starts at
        ``0.01 * i + 0.02`` seconds and lasts 10ms.

        Parameters
        ----------
        label_path
            Path to the text file holding one numeric value per line.

        Returns
        -------
        list
            One ``Label`` per kept frame, in file order.
        """
        with open(label_path, "r") as label_file:
            frames = label_file.readlines()

        voiced = []
        for frame_idx, raw in enumerate(frames):
            pitch = float(raw)
            if pitch < 0.1:
                # No pitch in this frame.
                continue
            onset = 0.01 * frame_idx + 0.02  # The first frame starts from 20ms.
            voiced.append(
                Label(start_time=onset, end_time=onset + 0.01, note=pitch))
        return voiced
예제 #13
0
 def load_label(cls, label_path):
     """Load a MIDI file and convert its non-drum notes into ``Label`` objects.

     Drum instruments are skipped. Every remaining note becomes a ``Label``
     carrying its start time, end time, pitch, velocity and the program
     number of its instrument. Labels whose ``note`` attribute equals -1
     after construction are dropped.

     Parameters
     ----------
     label_path
         Path to the MIDI file.

     Returns
     -------
     list
         The kept ``Label`` instances, in instrument and note order.
     """
     midi = pretty_midi.PrettyMIDI(label_path)
     labels = []
     for inst in midi.instruments:
         # The attribute is ``is_drum``; the former ``id_drum`` spelling does
         # not exist on pretty_midi instruments and raised AttributeError.
         if inst.is_drum:
             continue
         for note in inst.notes:
             label = Label(start_time=note.start,
                           end_time=note.end,
                           note=note.pitch,
                           velocity=note.velocity,
                           instrument=inst.program)
             if label.note == -1:
                 continue
             labels.append(label)
     return labels
예제 #14
0
    def load_label(cls, label_path):
        """Convert a ``time,frequency`` text file into ``Label`` objects.

        Each line is split on commas; the first field is read as a time in
        seconds and the second as a frequency in Hz. Rows with a frequency
        below 1e-10 are skipped. The frequency is converted with
        ``hz_to_midi`` and every label spans ``256 / 44100`` seconds.

        Parameters
        ----------
        label_path
            Path to the comma-separated label file.

        Returns
        -------
        list
            One ``Label`` per kept row, in file order.
        """
        with open(label_path, "r") as label_file:
            rows = label_file.readlines()

        hop_sec = 256 / 44100  # ~= 0.0058 secs
        parsed = []
        for row in rows:
            fields = row.strip().split(",")
            timestamp = float(fields[0])
            freq_hz = float(fields[1])
            if freq_hz < 1e-10:
                continue
            # Convert return type of np.float64 to float
            pitch = float(hz_to_midi(freq_hz))
            parsed.append(
                Label(start_time=timestamp, end_time=timestamp + hop_sec, note=pitch))

        return parsed
예제 #15
0
def test_mir1k_load_label():
    """Check ``MIR1KStructure.load_label`` against the MIR-1K ``.pv`` fixture.

    The parsed labels are compared with six hand-written reference
    ``Label`` objects through ``_assert_all_equal``.
    """
    # Positional ``Label`` arguments for each reference entry.
    reference_args = (
        (0.04, 0.06, 57.9108),
        (0.06, 0.08, 57.4161),
        (0.08, 0.1, 57.174),
        (0.1, 0.12, 59.7627),
        (0.12, 0.14, 60.0442),
        (0.14, 0.16, 60.3304),
    )
    reference = [Label(*args) for args in reference_args]
    parsed = dset.MIR1KStructure.load_label(
        "./tests/resource/gt_files/mir1k_gt_file.pv")
    _assert_all_equal(reference, parsed)
예제 #16
0
    def load_label(cls, label_path):
        """Load and parse labels for the given label file path.

        Parses the label information into the shared intermediate format,
        encapsulated in :class:`Label` instances. This default implementation
        reads the file as MIDI: drum instruments are skipped, and labels
        whose ``note`` attribute equals -1 after construction are dropped.
        """
        collected = []
        for track in pretty_midi.PrettyMIDI(label_path).instruments:
            if track.is_drum:
                continue
            for midi_note in track.notes:
                candidate = Label(start_time=midi_note.start,
                                  end_time=midi_note.end,
                                  note=midi_note.pitch,
                                  velocity=midi_note.velocity,
                                  instrument=track.program)
                if candidate.note != -1:
                    collected.append(candidate)
        return collected
예제 #17
0
    def load_label(cls, label_path):
        """Parse a comma-separated label file into ``Label`` objects.

        Each row must provide the columns ``start_time``, ``end_time``,
        ``instrument``, ``note``, ``start_beat``, ``end_beat`` and
        ``note_value``. The two time columns are divided by 44100 (presumably
        they are sample indices at a 44.1kHz sampling rate), and the
        instrument number is shifted down by one.

        Parameters
        ----------
        label_path
            Path to the CSV label file.

        Returns
        -------
        list
            One ``Label`` per row, in file order.
        """
        labels = []
        sample_rate = 44100
        with open(label_path, "r") as label_file:
            reader = csv.DictReader(label_file, delimiter=",")
            for row in reader:
                onset = float(row["start_time"]) / sample_rate
                offset = float(row["end_time"]) / sample_rate
                inst = int(row["instrument"]) - 1
                note = int(row["note"])
                start_beat = float(row["start_beat"])

                # The 'end_beat' column stores a length in beats rather than
                # an absolute position, so the start beat is added to obtain
                # the real ending beat. Passing the raw column through made
                # 'end_beat' smaller than 'start_beat' for later notes.
                end_beat = float(row["end_beat"]) + start_beat
                note_value = row["note_value"]

                label = Label(start_time=onset,
                              end_time=offset,
                              note=note,
                              instrument=inst,
                              start_beat=start_beat,
                              end_beat=end_beat,
                              note_value=note_value)
                labels.append(label)
        return labels
예제 #18
0
def test_musicnet_load_label():
    """Check ``MusicNetStructure.load_label`` against the MusicNet CSV fixture.

    The parsed labels are compared with ten hand-written reference
    ``Label`` objects through ``_assert_all_equal``.
    """
    # Each row: three positional ``Label`` arguments followed by the values
    # for ``instrument``, ``start_beat``, ``end_beat`` and ``note_value``.
    reference_rows = (
        (0.231428, 0.962857, 61, 41, 0, 1.489583, "Dotted Half"),
        (0.231428, 0.568117, 65, 6, 0, 0.489583, "Quarter"),
        (0.231428, 0.510068, 46, 0, 0, 0.333333, "Dotted Eighth"),
        (0.579727, 0.660997, 63, 40, 0.5, 0.739583, "Eighth"),
        (0.579727, 0.695827, 58, 70, 0.5, 0.833333, "Dotted Eighth"),
        (0.672607, 0.777097, 65, 71, 0.75, 0.989583, "Eighth"),
        (0.777097, 1.473696, 58, 73, 1.0, 2.989583, "Whole"),
        (0.777097, 0.904807, 66, 68, 1.0, 1.333333, "Dotted Eighth"),
        (0.962857, 1.055714, 63, 0, 1.5, 1.739583, "Eighth"),
        (0.962857, 1.090566, 39, 43, 1.5, 1.833333, "Dotted Eighth"),
    )
    reference = [
        Label(first,
              second,
              third,
              instrument=program,
              start_beat=beat_on,
              end_beat=beat_off,
              note_value=value_name)
        for first, second, third, program, beat_on, beat_off, value_name in reference_rows
    ]
    parsed = dset.MusicNetStructure.load_label(
        "./tests/resource/gt_files/musicnet_gt_file.csv")
    _assert_all_equal(reference, parsed)