def generate_test_set():
  """Generate the test TFRecord."""
  test_file_pairs = []
  for directory in test_dirs:
    path = os.path.join(FLAGS.input_dir, directory)
    path = os.path.join(path, '*.wav')
    wav_files = glob.glob(path)
    # Find the matching .mid file for every .wav file.
    for wav_file in wav_files:
      base_name_root, _ = os.path.splitext(wav_file)
      mid_file = base_name_root + '.mid'
      test_file_pairs.append((wav_file, mid_file))

  test_output_name = os.path.join(FLAGS.output_dir,
                                  'maps_config2_test.tfrecord')

  with tf.python_io.TFRecordWriter(test_output_name) as writer:
    for idx, pair in enumerate(test_file_pairs):
      print('{} of {}: {}'.format(idx, len(test_file_pairs), pair[0]))
      # Load the wav data and resample it.
      samples = audio_io.load_audio(pair[0], FLAGS.sample_rate)
      wav_data = audio_io.samples_to_wav_data(samples, FLAGS.sample_rate)

      # Load the midi data and convert it to a NoteSequence.
      ns = midi_io.midi_file_to_note_sequence(pair[1])

      example = audio_label_data_utils.create_example(pair[0], ns, wav_data)
      writer.write(example.SerializeToString())

  return [filename_to_id(wav) for wav, _ in test_file_pairs]
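
# Sketch (not from the original script): a quick way to sanity-check the
# TFRecord that generate_test_set() writes. It assumes the same TF1-style API
# (tf.python_io) used above and assumes that
# audio_label_data_utils.create_example stores the example id under the 'id'
# feature key; the helper name _inspect_test_set is made up for illustration.
def _inspect_test_set(output_dir):
  """Prints the id of every example in the generated test TFRecord."""
  record_path = os.path.join(output_dir, 'maps_config2_test.tfrecord')
  for raw in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example.FromString(raw)
    print(example.features.feature['id'].bytes_list.value[0].decode('utf-8'))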
def process(self, paths):
  midi_path, wav_path_base = paths

  if self._add_wav_glob:
    wav_paths = tf.io.gfile.glob(wav_path_base + '*')
  else:
    wav_paths = [wav_path_base]

  if midi_path:
    base_ns = midi_io.midi_file_to_note_sequence(midi_path)
    base_ns.filename = midi_path
  else:
    base_ns = music_pb2.NoteSequence()

  for wav_path in wav_paths:
    logging.info('Creating Example %s:%s', midi_path, wav_path)
    wav_data = tf.io.gfile.GFile(wav_path, 'rb').read()

    ns = copy.deepcopy(base_ns)

    # Use base names (paths relative to the wav/midi directories) as the id.
    ns.id = '%s:%s' % (wav_path.replace(self._wav_dir, ''),
                       midi_path.replace(self._midi_dir, ''))

    Metrics.counter('create_example', 'read_midi_wav').inc()

    example = audio_label_data_utils.create_example(ns.id, ns, wav_data)

    Metrics.counter('create_example', 'created_example').inc()
    yield example
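
# Sketch (hypothetical wiring, not from the original source): how a DoFn with
# the process() method above might be run in an Apache Beam pipeline. The class
# name CreateExampleDoFn, its constructor arguments, and the step/path names
# are assumptions for illustration only.
def run_pipeline(midi_wav_pairs, wav_dir, midi_dir, output_path):
  """Turns (midi_path, wav_path_base) pairs into a TFRecord of Examples."""
  import apache_beam as beam

  with beam.Pipeline() as p:
    _ = (
        p
        | 'create_inputs' >> beam.Create(midi_wav_pairs)
        | 'create_examples' >> beam.ParDo(
            CreateExampleDoFn(wav_dir, midi_dir, add_wav_glob=True))
        | 'serialize' >> beam.Map(lambda ex: ex.SerializeToString())
        | 'write' >> beam.io.WriteToTFRecord(output_path))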
def predict(self, path: str, wav_data=None):
    """Using the model, return the predicted note sequence of a .wav file
    at the given path.

    Args:
      path (str): The path to a .wav audio file. If path is "binary", then a
        binary must be supplied via wav_data.
      wav_data (bytes): The binary for the .wav file, if that is easier to
        extract. Defaults to None when path is provided.

    Returns:
      NoteSequence object containing the prediction. Convertible to MIDI.
    """
    if path == "binary":
        if wav_data is None:
            raise ValueError(
                "The binary option is chosen but a binary is not provided.")
    else:
        # Read the raw .wav bytes from disk.
        with open(path, "rb") as f:
            wav_data = f.read()

    ns = note_seq.NoteSequence()
    example_list = [
        audio_label_data_utils.create_example(
            path, ns, wav_data,
            velocity_range=audio_label_data_utils.
            velocity_range_from_sequence(ns))
    ]
    to_process = [example_list[0].SerializeToString()]
    print('Processing complete for', path)

    sess = tf.Session()
    sess.run([
        tf.initializers.global_variables(),
        tf.initializers.local_variables()
    ])
    sess.run(self.iterator.initializer, {self.examples: to_process})

    def transcription_data(params):
        del params
        return tf.data.Dataset.from_tensors(sess.run(self.next_record))

    input_fn = infer_util.labels_to_features_wrapper(transcription_data)

    prediction_list = list(
        self.estimator.predict(input_fn, yield_single_examples=False))
    assert len(prediction_list) == 1

    sequence_prediction = note_seq.NoteSequence.FromString(
        prediction_list[0]['sequence_predictions'][0])
    return sequence_prediction
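
# Sketch (hypothetical usage, not part of the original class): driving
# predict() once an instance of the enclosing transcription wrapper exists.
# The argument name `transcriber`, the helper name, and the use of
# note_seq.sequence_proto_to_midi_file are assumptions; only predict() itself
# is taken from the code above.
def transcribe_wav_to_midi(transcriber, wav_path, midi_out_path):
    """Transcribes a .wav file and writes the predicted NoteSequence as MIDI."""
    sequence = transcriber.predict(wav_path)
    # note_seq can serialize a NoteSequence proto back into a standard MIDI file.
    note_seq.sequence_proto_to_midi_file(sequence, midi_out_path)
    return sequence


# The raw-bytes variant reuses predict()'s "binary" convention:
#   with open(wav_path, 'rb') as f:
#       sequence = transcriber.predict('binary', wav_data=f.read())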
def process(self, paths):
  wav_path, midi_path = paths

  if midi_path:
    if FLAGS.use_midi_stems:
      # Merge all MIDI stems in the directory into a single NoteSequence.
      base_ns = note_sequence_from_directory(os.path.dirname(midi_path))
    else:
      base_ns = midi_io.midi_file_to_note_sequence(midi_path)
    base_ns.filename = midi_path
  else:
    base_ns = music_pb2.NoteSequence()

  logging.info('Creating Example %s:%s', midi_path, wav_path)
  if FLAGS.convert_flac:
    # Decode the audio with librosa at the target rate, then re-encode as wav.
    samples, sr = librosa.load(wav_path, FLAGS.sample_rate)
    wav_data = audio_io.samples_to_wav_data(samples, sr)
  else:
    wav_data = tf.io.gfile.GFile(wav_path, 'rb').read()

  ns = copy.deepcopy(base_ns)

  # Combine the audio and MIDI paths into the example id.
  ns.id = '%s:%s' % (wav_path, midi_path)

  Metrics.counter('create_example', 'read_midi_wav').inc()

  if FLAGS.max_length > 0:
    # Split long recordings into examples between min_length and max_length.
    split_examples = audio_label_data_utils.process_record(
        wav_data,
        ns,
        ns.id,
        min_length=FLAGS.min_length,
        max_length=FLAGS.max_length,
        sample_rate=FLAGS.sample_rate,
        load_audio_with_librosa=False)

    for example in split_examples:
      Metrics.counter('split_wav', 'split_example').inc()
      yield example
  else:
    example = audio_label_data_utils.create_example(ns.id, ns, wav_data)
    Metrics.counter('create_example', 'created_example').inc()
    yield example