# Imports assumed by the snippets below (reconstructed for context; the
# original files' exact import paths may differ):
import copy
import glob
import os

from absl import flags
from absl import logging
from apache_beam.metrics import Metrics
import librosa
import note_seq
from note_seq import audio_io
from note_seq import midi_io
from note_seq.protobuf import music_pb2
import tensorflow.compat.v1 as tf

from magenta.models.onsets_frames_transcription import audio_label_data_utils
from magenta.models.onsets_frames_transcription import infer_util

FLAGS = flags.FLAGS


def generate_test_set():
    """Generate the test TFRecord from the MAPS test directories."""
    test_file_pairs = []
    for directory in test_dirs:
        path = os.path.join(FLAGS.input_dir, directory)
        path = os.path.join(path, '*.wav')
        wav_files = glob.glob(path)
        # Find the matching .mid file for each .wav.
        for wav_file in wav_files:
            base_name_root, _ = os.path.splitext(wav_file)
            mid_file = base_name_root + '.mid'
            test_file_pairs.append((wav_file, mid_file))

    test_output_name = os.path.join(FLAGS.output_dir,
                                    'maps_config2_test.tfrecord')

    with tf.python_io.TFRecordWriter(test_output_name) as writer:
        for idx, pair in enumerate(test_file_pairs):
            print('{} of {}: {}'.format(idx + 1, len(test_file_pairs), pair[0]))
            # Load the .wav data and resample it to FLAGS.sample_rate.
            samples = audio_io.load_audio(pair[0], FLAGS.sample_rate)
            wav_data = audio_io.samples_to_wav_data(samples, FLAGS.sample_rate)

            # Load the MIDI file and convert it to a NoteSequence.
            ns = midi_io.midi_file_to_note_sequence(pair[1])

            example = audio_label_data_utils.create_example(
                pair[0], ns, wav_data)
            writer.write(example.SerializeToString())

    return [filename_to_id(wav) for wav, _ in test_file_pairs]
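
# Note: `filename_to_id` is used above but not defined in this snippet. A
# minimal sketch, assuming the id is simply the file's base name without its
# extension (the original helper may instead parse the MAPS naming
# convention):
def filename_to_id(filename):
    """Return an id for a .wav/.mid path (illustrative sketch only)."""
    return os.path.splitext(os.path.basename(filename))[0]
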
    def process(self, paths):
        midi_path, wav_path_base = paths

        if self._add_wav_glob:
            wav_paths = tf.io.gfile.glob(wav_path_base + '*')
        else:
            wav_paths = [wav_path_base]

        if midi_path:
            base_ns = midi_io.midi_file_to_note_sequence(midi_path)
            base_ns.filename = midi_path
        else:
            base_ns = music_pb2.NoteSequence()

        for wav_path in wav_paths:
            logging.info('Creating Example %s:%s', midi_path, wav_path)
            wav_data = tf.io.gfile.GFile(wav_path, 'rb').read()

            ns = copy.deepcopy(base_ns)

            # Use paths relative to the input directories as the example id.
            ns.id = '%s:%s' % (wav_path.replace(self._wav_dir, ''),
                               midi_path.replace(self._midi_dir, ''))

            Metrics.counter('create_example', 'read_midi_wav').inc()

            example = audio_label_data_utils.create_example(
                ns.id, ns, wav_data)

            Metrics.counter('create_example', 'created_example').inc()
            yield example
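
# A minimal sketch of how a DoFn with the `process` method above might be
# wired into an Apache Beam pipeline. `CreateExampleDoFn`, the (midi, wav)
# input pairs, and the output path are hypothetical stand-ins:
#
#   import apache_beam as beam
#
#   with beam.Pipeline() as p:
#       _ = (p
#            | beam.Create([('/data/song.mid', '/data/song')])
#            | beam.ParDo(CreateExampleDoFn())
#            | beam.Map(lambda ex: ex.SerializeToString())
#            | beam.io.WriteToTFRecord('/data/examples.tfrecord'))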
    def predict(self, path: str, wav_data=None):
        """Using the model, return the predicted note sequence of a .wav file at the given path.

        Args:
            path (str): The path to a .wav audio file. If path is "binary", then a binary must be specified.
            wav_data (bytes): The binary for the .wav file if that is easier to extract. Defaults to None whqen path is provided.

        Returns:
            NoteSequence object containing the prediction. Convertable to MIDI.
        """
        if path == "binary":
            if wav_data is None:
                raise ValueError(
                    "The binary option is chosen but a binary is not provided."
                )
        else:
            with open(path, "rb") as f:
                wav_data = f.read()

        ns = note_seq.NoteSequence()
        velocity_range = audio_label_data_utils.velocity_range_from_sequence(ns)
        example_list = [
            audio_label_data_utils.create_example(
                path, ns, wav_data, velocity_range=velocity_range)
        ]
        to_process = [example_list[0].SerializeToString()]

        print('Processing complete for', path)

        sess = tf.Session()

        sess.run([
            tf.initializers.global_variables(),
            tf.initializers.local_variables()
        ])

        sess.run(self.iterator.initializer, {self.examples: to_process})

        def transcription_data(params):
            del params
            return tf.data.Dataset.from_tensors(sess.run(self.next_record))

        input_fn = infer_util.labels_to_features_wrapper(transcription_data)
        prediction_list = list(
            self.estimator.predict(input_fn, yield_single_examples=False))
        assert len(prediction_list) == 1
        sequence_prediction = note_seq.NoteSequence.FromString(
            prediction_list[0]['sequence_predictions'][0])

        return sequence_prediction
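
# Example usage of predict() (the `Transcriber` wrapper class and file paths
# are hypothetical), converting the returned NoteSequence to a MIDI file:
#
#   transcriber = Transcriber(...)  # the class owning predict() above
#   sequence = transcriber.predict('piano.wav')
#   note_seq.sequence_proto_to_midi_file(sequence, 'piano.mid')
#
#   # Or pass raw bytes instead of a path:
#   with open('piano.wav', 'rb') as f:
#       sequence = transcriber.predict('binary', wav_data=f.read())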
    def process(self, paths):
        wav_path, midi_path = paths

        if midi_path:
            if FLAGS.use_midi_stems:
                base_ns = note_sequence_from_directory(
                    os.path.dirname(midi_path))
            else:
                base_ns = midi_io.midi_file_to_note_sequence(midi_path)
            base_ns.filename = midi_path
        else:
            base_ns = music_pb2.NoteSequence()

        logging.info('Creating Example %s:%s', midi_path, wav_path)
        if FLAGS.convert_flac:
            # librosa >= 0.10 requires the sample rate as a keyword argument.
            samples, sr = librosa.load(wav_path, sr=FLAGS.sample_rate)
            wav_data = audio_io.samples_to_wav_data(samples, sr)
        else:
            wav_data = tf.io.gfile.GFile(wav_path, 'rb').read()

        ns = copy.deepcopy(base_ns)

        # Use the full wav and midi paths as the example id.
        ns.id = '%s:%s' % (wav_path, midi_path)

        Metrics.counter('create_example', 'read_midi_wav').inc()

        if FLAGS.max_length > 0:
            split_examples = audio_label_data_utils.process_record(
                wav_data,
                ns,
                ns.id,
                min_length=FLAGS.min_length,
                max_length=FLAGS.max_length,
                sample_rate=FLAGS.sample_rate,
                load_audio_with_librosa=False)

            for example in split_examples:
                Metrics.counter('split_wav', 'split_example').inc()
                yield example
        else:
            example = audio_label_data_utils.create_example(
                ns.id, ns, wav_data)

            Metrics.counter('create_example', 'created_example').inc()
            yield example
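
# `note_sequence_from_directory` (used when FLAGS.use_midi_stems is set) is
# not shown above. A minimal sketch, assuming it merges every per-stem .mid
# file in the directory into a single NoteSequence; an illustration, not the
# original helper:
def note_sequence_from_directory(dir_path):
    """Merge all MIDI stems under dir_path into one NoteSequence (sketch)."""
    merged = music_pb2.NoteSequence()
    for midi_path in tf.io.gfile.glob(os.path.join(dir_path, '*.mid')):
        stem = midi_io.midi_file_to_note_sequence(midi_path)
        merged.notes.extend(stem.notes)
        merged.total_time = max(merged.total_time, stem.total_time)
    return merged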