def testSplitMidi(self):
    """find_split_points: even 3s chunks on silence, shifted near audio activity."""
    ns = music_pb2.NoteSequence()
    for onset, offset in ((1.0, 2.9), (8.0, 11.0), (14.0, 17.0), (20.0, 23.0)):
        ns.notes.add(pitch=60, start_time=onset, end_time=offset)
    ns.total_time = 25.

    rate = 160
    audio = np.zeros(rate * int(ns.total_time))

    # All-zero audio: the file is carved into even 3-second chunks, with the
    # final boundary clamped to total_time.
    points = audio_label_data_utils.find_split_points(ns, audio, rate, 0, 3)
    self.assertEqual(points, [0., 3., 6., 9., 12., 15., 18., 21., 24., 25.])

    # Mark two adjacent non-zero samples around t=8.5s; the split points from
    # that position onward shift to 8.50625, 11.50625, ... (presumably the
    # splitter snaps to this transition — behavior pinned by the expected list).
    audio[int(8.5 * rate)] = 1
    audio[int(8.5 * rate) + 1] = -1
    points = audio_label_data_utils.find_split_points(ns, audio, rate, 0, 3)
    self.assertEqual(points, [
        0.0, 3.0, 6.0, 8.50625, 11.50625, 14.50625, 17.50625, 20.50625,
        23.50625, 25.
    ])
def split2batch(audio, sequence):
    """Split an (audio, NoteSequence) pair into fixed-length padded examples.

    Args:
      audio: 1-D numpy array of samples, assumed to be at cfg.SAMPLE_RATE
        (TODO confirm with callers).
      sequence: NoteSequence spanning the same recording.

    Returns:
      A list of (split_audio, split_sequence) tuples. Each split_audio is
      zero-padded up to ceil(cfg.MAX_SPLIT_LENGTH * cfg.SAMPLE_RATE) samples.
      Windows shorter than cfg.MIN_SPLIT_LENGTH are dropped.
    """
    from magenta.models.onsets_frames_transcription.audio_label_data_utils import find_split_points

    # Zero-pad the audio so it covers the sequence's full duration.
    shortfall = int(math.ceil(
        sequence.total_time * cfg.SAMPLE_RATE)) - audio.shape[0]
    if shortfall > 0:
        audio = np.concatenate((audio, np.zeros(shortfall, dtype=audio.dtype)))

    if cfg.MAX_SPLIT_LENGTH == 0:
        # Splitting disabled: one window covering the whole recording.
        boundaries = [0, sequence.total_time]
    else:
        boundaries = find_split_points(
            sequence, audio, cfg.SAMPLE_RATE,
            cfg.MIN_SPLIT_LENGTH, cfg.MAX_SPLIT_LENGTH)

    batch = []
    for begin, finish in zip(boundaries[:-1], boundaries[1:]):
        # Discard windows below the configured minimum duration.
        if finish - begin < cfg.MIN_SPLIT_LENGTH:
            continue
        if begin == 0 and finish == sequence.total_time:
            # Full-span window: reuse the originals without cropping.
            piece_audio, piece_seq = audio, sequence
        else:
            piece_seq = sequences_lib.extract_subsequence(
                sequence, begin, finish)
            piece_audio = audio_io.crop_samples(
                audio, cfg.SAMPLE_RATE, begin, finish - begin)
        # Pad every window out to the maximum split length so all examples
        # in the batch share one audio shape.
        shortfall = int(math.ceil(
            cfg.MAX_SPLIT_LENGTH * cfg.SAMPLE_RATE)) - piece_audio.shape[0]
        if shortfall > 0:
            piece_audio = np.concatenate(
                (piece_audio, np.zeros(shortfall, dtype=piece_audio.dtype)))
        batch.append((piece_audio, piece_seq))
    return batch