Beispiel #1
0
    def testSplitMidi(self):
        """Check find_split_points on all-zero audio, then with a +1/-1 pair.

        Four pitch-60 notes are placed in a 25 second sequence and the
        splitter is called with min_length=0 and max_length=3. With all-zero
        samples the expected splits fall every 3 seconds. After a +1/-1
        sample pair is written at 8.5 seconds, the expected split moves from
        9.0 to 8.50625 and the later splits follow at 3 second steps.
        """
        note_spans = (
            (1.0, 2.9),
            (8.0, 11.0),
            (14.0, 17.0),
            (20.0, 23.0),
        )
        sequence = music_pb2.NoteSequence()
        for onset, offset in note_spans:
            sequence.notes.add(pitch=60, start_time=onset, end_time=offset)
        sequence.total_time = 25.

        sample_rate = 160
        min_length, max_length = 0, 3
        samples = np.zeros(sample_rate * int(sequence.total_time))

        def current_splits():
            # Re-evaluated on each call so it sees later edits to `samples`.
            return create_dataset_util.find_split_points(
                sequence, samples, sample_rate, min_length, max_length)

        # Case 1: all-zero audio.
        expected_silent = [0., 3., 6., 9., 12., 15., 18., 21., 24., 25.]
        self.assertEqual(current_splits(), expected_silent)

        # Case 2: two adjacent non-zero samples of opposite sign at 8.5 s
        # (presumably a zero crossing the splitter snaps to -- see
        # find_split_points for the actual rule).
        pulse_index = int(8.5 * sample_rate)
        samples[pulse_index] = 1
        samples[pulse_index + 1] = -1

        expected_with_pulse = [
            0.0, 3.0, 6.0,
            8.50625, 11.50625, 14.50625, 17.50625, 20.50625, 23.50625,
            25.,
        ]
        self.assertEqual(current_splits(), expected_with_pulse)
def generate_train_set(exclude_ids):
    """Generate the train TFRecord.

    Collects every ``*.wav`` file found in the ``TRAIN_DIRS`` sub-directories
    of ``FLAGS.input_dir`` and pairs it with the ``.mid`` file of the same
    base name. Pairs whose ``filename_to_id(wav_file)`` is in ``exclude_ids``
    are skipped. Each remaining recording is cut at the points returned by
    ``create_dataset_util.find_split_points`` and one ``tf.train.Example``
    per segment is written to ``maps_config2_train.tfrecord`` inside
    ``FLAGS.output_dir``. Segments shorter than ``FLAGS.min_length`` are
    dropped.

    Each example carries four bytes features: ``id`` (the wav path),
    ``sequence`` (serialized note subsequence), ``audio`` (wav data of the
    segment) and ``velocity_range`` (serialized min/max note velocity of the
    whole recording).

    Args:
        exclude_ids: Container of ids, compared against the result of
            ``filename_to_id``, whose recordings are left out.
    """

    def bytes_feature(value):
        # Wrap a single bytes value in a tf.train.Feature.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    train_file_pairs = []
    for directory in TRAIN_DIRS:
        pattern = os.path.join(FLAGS.input_dir, directory, '*.wav')
        # find matching mid files
        for wav_file in glob.glob(pattern):
            base_name_root, _ = os.path.splitext(wav_file)
            mid_file = base_name_root + '.mid'
            if filename_to_id(wav_file) not in exclude_ids:
                train_file_pairs.append((wav_file, mid_file))

    train_output_name = os.path.join(FLAGS.output_dir,
                                     'maps_config2_train.tfrecord')
    pair_count = len(train_file_pairs)

    with tf.python_io.TFRecordWriter(train_output_name) as writer:
        for idx, (wav_path, mid_path) in enumerate(train_file_pairs):
            print("{} of {}: {}".format(idx, pair_count, wav_path))

            # load the wav data; the context manager closes the handle
            # instead of leaking one open file per recording.
            with tf.gfile.Open(wav_path, 'rb') as wav_handle:
                wav_data = wav_handle.read()
            samples = audio_io.wav_data_to_samples(wav_data, FLAGS.sample_rate)
            norm_samples = librosa.util.normalize(samples, norm=np.inf)

            # load the midi data and convert to a notesequence
            ns = midi_io.midi_file_to_note_sequence(mid_path)

            splits = create_dataset_util.find_split_points(
                ns, norm_samples, FLAGS.sample_rate, FLAGS.min_length,
                FLAGS.max_length)

            # NOTE(review): np.max / np.min raise ValueError when the
            # sequence contains no notes.
            velocities = [note.velocity for note in ns.notes]
            velocity_max = np.max(velocities)
            velocity_min = np.min(velocities)
            new_velocity_tuple = music_pb2.VelocityRange(min=velocity_min,
                                                         max=velocity_max)
            # The velocity range is per recording, so serialize it once.
            velocity_range_bytes = new_velocity_tuple.SerializeToString()

            for start, end in zip(splits[:-1], splits[1:]):
                if end - start < FLAGS.min_length:
                    continue

                new_ns = sequences_lib.extract_subsequence(ns, start, end)
                samples_start = int(start * FLAGS.sample_rate)
                samples_end = samples_start + int(
                    (end - start) * FLAGS.sample_rate)
                # NOTE(review): split points come from the peak-normalized
                # signal, but the audio written here is sliced from the
                # un-normalized `samples` -- confirm this is intended.
                new_samples = samples[samples_start:samples_end]
                new_wav_data = audio_io.samples_to_wav_data(
                    new_samples, FLAGS.sample_rate)

                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'id': bytes_feature(wav_path.encode()),
                        'sequence': bytes_feature(new_ns.SerializeToString()),
                        'audio': bytes_feature(new_wav_data),
                        'velocity_range': bytes_feature(velocity_range_bytes),
                    }))
                writer.write(example.SerializeToString())