# Test method from a unittest-style class (e.g. tf.test.TestCase); assumed imports:
#   import numpy as np
#   from magenta.models.onsets_frames_transcription import audio_label_data_utils
#   from magenta.protobuf import music_pb2  # note_seq.protobuf.music_pb2 in newer releases
def testSplitMidi(self):
        sequence = music_pb2.NoteSequence()
        sequence.notes.add(pitch=60, start_time=1.0, end_time=2.9)
        sequence.notes.add(pitch=60, start_time=8.0, end_time=11.0)
        sequence.notes.add(pitch=60, start_time=14.0, end_time=17.0)
        sequence.notes.add(pitch=60, start_time=20.0, end_time=23.0)
        sequence.total_time = 25.

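        # With silent audio, the split points fall on the regular 3 s
        # (max_length) grid, with total_time appended at the end.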
        sample_rate = 160
        samples = np.zeros(sample_rate * int(sequence.total_time))
        splits = audio_label_data_utils.find_split_points(
            sequence, samples, sample_rate, 0, 3)

        self.assertEqual(splits,
                         [0., 3., 6., 9., 12., 15., 18., 21., 24., 25.])

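        # Add a small non-zero blip just after 8.5 s; the later split points
        # shift off the 3 s grid to 8.50625 s and onward, tracking the audio
        # activity around the blip.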
        samples[int(8.5 * sample_rate)] = 1
        samples[int(8.5 * sample_rate) + 1] = -1
        splits = audio_label_data_utils.find_split_points(
            sequence, samples, sample_rate, 0, 3)

        self.assertEqual(splits, [
            0.0, 3.0, 6.0, 8.50625, 11.50625, 14.50625, 17.50625, 20.50625,
            23.50625, 25.
        ])
Example 2
import math

import numpy as np

# `cfg` is assumed to be a project-local config module providing SAMPLE_RATE,
# MIN_SPLIT_LENGTH and MAX_SPLIT_LENGTH; it is not part of magenta.
import cfg
# Import paths assume an older magenta layout; newer releases expose these
# modules through the note_seq package instead.
from magenta.models.onsets_frames_transcription.audio_label_data_utils import find_split_points
from magenta.music import audio_io
from magenta.music import sequences_lib


def split2batch(audio, sequence):
    """Splits audio and its NoteSequence into aligned, padded chunks."""
    # Zero-pad the audio so it covers the full duration of the NoteSequence.
    pad_num = int(math.ceil(
        sequence.total_time * cfg.SAMPLE_RATE)) - audio.shape[0]
    if pad_num > 0:
        audio = np.concatenate((audio, np.zeros(pad_num, dtype=audio.dtype)))

    # Keep the sequence whole when MAX_SPLIT_LENGTH is 0; otherwise let
    # find_split_points choose boundaries between MIN_ and MAX_SPLIT_LENGTH.
    if cfg.MAX_SPLIT_LENGTH == 0:
        splits = [0, sequence.total_time]
    else:
        splits = find_split_points(sequence, audio, cfg.SAMPLE_RATE,
                                   cfg.MIN_SPLIT_LENGTH, cfg.MAX_SPLIT_LENGTH)

    samples = []
    for start, end in zip(splits[:-1], splits[1:]):
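        # Skip leftover chunks shorter than MIN_SPLIT_LENGTH (typically the tail).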
        if end - start < cfg.MIN_SPLIT_LENGTH:
            continue

        # Extract the matching NoteSequence slice (unless the chunk is the
        # whole sequence) and crop the corresponding audio window.
        split_seq = sequence
        if not (start == 0 and end == sequence.total_time):
            split_seq = sequences_lib.extract_subsequence(sequence, start, end)
        split_audio = audio_io.crop_samples(audio, cfg.SAMPLE_RATE, start,
                                            end - start)
        # Zero-pad every chunk to exactly MAX_SPLIT_LENGTH seconds so the
        # resulting batch has a fixed shape.
        pad_num = int(math.ceil(
            cfg.MAX_SPLIT_LENGTH * cfg.SAMPLE_RATE)) - split_audio.shape[0]
        if pad_num > 0:
            split_audio = np.concatenate(
                (split_audio, np.zeros(pad_num, dtype=split_audio.dtype)))

        samples.append((split_audio, split_seq))

    return samples
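
A hedged usage sketch for split2batch, assuming the same older magenta import layout as in the example above; the file paths and the midi_io/audio_io loading calls are illustrative assumptions, not part of the original snippet.

# Illustrative only: 'piece.wav' and 'piece.mid' are placeholder paths.
from magenta.music import audio_io, midi_io

with open('piece.wav', 'rb') as f:
    wav_data = f.read()
audio = audio_io.wav_data_to_samples(wav_data, cfg.SAMPLE_RATE)
sequence = midi_io.midi_file_to_sequence_proto('piece.mid')

for chunk_audio, chunk_seq in split2batch(audio, sequence):
    print(chunk_audio.shape, chunk_seq.total_time)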