def _CreateExamplesAndExpectedInputs(self, truncated_length, lengths,
                                     expected_num_inputs):
  hparams = copy.deepcopy(configs.DEFAULT_HPARAMS)
  examples = []
  expected_inputs = []
  for i, length in enumerate(lengths):
    wav_samples = np.zeros(
        (int((length / data.hparams_frames_per_second(hparams)) *
             hparams.sample_rate), 1), np.float32)
    wav_data = audio_io.samples_to_wav_data(wav_samples, hparams.sample_rate)
    num_frames = data.wav_to_num_frames(
        wav_data, frames_per_second=data.hparams_frames_per_second(hparams))
    seq = self._SyntheticSequence(
        num_frames / data.hparams_frames_per_second(hparams),
        i + constants.MIN_MIDI_PITCH)
    examples.append(self._FillExample(seq, wav_data, 'ex%d' % i))
    expected_inputs += self._ExampleToInputs(examples[-1], truncated_length)
  self.assertEqual(expected_num_inputs, len(expected_inputs))
  return examples, expected_inputs
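# A hedged numeric sketch (not from the original source) of the sizing math
# used by the test helper above: a clip meant to span `length` frames is
# backed by int((length / frames_per_second) * sample_rate) zero samples.
# Assuming 16 kHz audio and 512-sample spectrogram hops (common Onsets and
# Frames settings), frames_per_second would be 16000 / 512 = 31.25, so a
# 10-frame clip needs int((10 / 31.25) * 16000) = 5120 samples.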
def predict_sequence(frame_predictions, onset_predictions, offset_predictions,
                     velocity_values, min_pitch, hparams):
  """Predict sequence given model output."""
  if not hparams.predict_onset_threshold:
    onset_predictions = None
  if not hparams.predict_offset_threshold:
    offset_predictions = None

  if hparams.onset_only_sequence_prediction:
    # Use an identity check: `not <ndarray>` raises "truth value is
    # ambiguous" whenever onsets are actually present.
    if onset_predictions is None:
      raise ValueError(
          'Cannot do onset only prediction if onsets are not defined.')
    sequence_prediction = sequences_lib.pianoroll_onsets_to_note_sequence(
        onsets=onset_predictions,
        frames_per_second=data.hparams_frames_per_second(hparams),
        note_duration_seconds=0.05,
        min_midi_pitch=min_pitch,
        velocity_values=velocity_values)
  else:
    sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
        frames=frame_predictions,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_duration_ms=0,
        min_midi_pitch=min_pitch,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values)

  return sequence_prediction
def _note_metrics(labels, predictions):
  """A pyfunc that wraps a call to precision_recall_f1_overlap.

  Note: `hparams` and `offset_ratio` are captured from the enclosing scope.
  """
  est_sequence = pianoroll_to_note_sequence(
      predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=hparams.min_duration_ms)
  ref_sequence = pianoroll_to_note_sequence(
      labels,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=hparams.min_duration_ms)

  est_intervals, est_pitches = sequence_to_valued_intervals(
      est_sequence, hparams.min_duration_ms)
  ref_intervals, ref_pitches = sequence_to_valued_intervals(
      ref_sequence, hparams.min_duration_ms)

  if est_intervals.size == 0 or ref_intervals.size == 0:
    return 0., 0., 0.
  note_precision, note_recall, note_f1, _ = precision_recall_f1_overlap(
      ref_intervals, pretty_midi.note_number_to_hz(ref_pitches),
      est_intervals, pretty_midi.note_number_to_hz(est_pitches),
      offset_ratio=offset_ratio)

  return note_precision, note_recall, note_f1
def _ExampleToInputs(self, ex, truncated_length=0):
  hparams = copy.deepcopy(configs.DEFAULT_HPARAMS)

  filename = ex.features.feature['id'].bytes_list.value[0]
  sequence = music_pb2.NoteSequence.FromString(
      ex.features.feature['sequence'].bytes_list.value[0])
  wav_data = ex.features.feature['audio'].bytes_list.value[0]

  spec = data.wav_to_spec(wav_data, hparams=hparams)
  roll = sequences_lib.sequence_to_pianoroll(
      sequence,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_pitch=constants.MIN_MIDI_PITCH,
      max_pitch=constants.MAX_MIDI_PITCH,
      min_frame_occupancy_for_label=0.0,
      onset_mode='length_ms',
      onset_length_ms=32.,
      onset_delay_ms=0.)
  length = data.wav_to_num_frames(
      wav_data, frames_per_second=data.hparams_frames_per_second(hparams))

  return self._DataToInputs(spec, roll.active, roll.weights, length, filename,
                            truncated_length)
def _ExampleToInputs(self,
                     ex,
                     truncated_length=0,
                     crop_training_sequence_to_notes=False):
  hparams = copy.deepcopy(constants.DEFAULT_HPARAMS)
  hparams.crop_training_sequence_to_notes = crop_training_sequence_to_notes

  filename = ex.features.feature['id'].bytes_list.value[0]
  sequence, crop_beginning_seconds = data.preprocess_sequence(
      ex.features.feature['sequence'].bytes_list.value[0], hparams)
  wav_data = ex.features.feature['audio'].bytes_list.value[0]
  if crop_training_sequence_to_notes:
    wav_data = audio_io.crop_wav_data(wav_data, hparams.sample_rate,
                                      crop_beginning_seconds,
                                      sequence.total_time)
  spec = data.wav_to_spec(wav_data, hparams=hparams)
  roll = sequences_lib.sequence_to_pianoroll(
      sequence,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_pitch=constants.MIN_MIDI_PITCH,
      max_pitch=constants.MAX_MIDI_PITCH,
      min_frame_occupancy_for_label=0.0,
      onset_mode='length_ms',
      onset_length_ms=32.,
      onset_delay_ms=0.)
  length = data.wav_to_num_frames(
      wav_data, frames_per_second=data.hparams_frames_per_second(hparams))

  return self._DataToInputs(spec, roll.active, roll.weights, length, filename,
                            truncated_length)
def _ValidateProvideBatchMemory(self, truncated_length, batch_size, lengths,
                                expected_num_inputs):
  hparams = copy.deepcopy(constants.DEFAULT_HPARAMS)
  examples = []
  expected_inputs = []
  for i, length in enumerate(lengths):
    wav_samples = np.zeros(
        (int((length / data.hparams_frames_per_second(hparams)) *
             hparams.sample_rate), 1), np.float32)
    wav_data = audio_io.samples_to_wav_data(wav_samples, hparams.sample_rate)
    num_frames = data.wav_to_num_frames(
        wav_data, frames_per_second=data.hparams_frames_per_second(hparams))
    seq = self._SyntheticSequence(
        num_frames / data.hparams_frames_per_second(hparams),
        i + constants.MIN_MIDI_PITCH)
    examples.append(self._FillExample(seq, wav_data, 'ex%d' % i))
    expected_inputs += self._ExampleToInputs(examples[-1], truncated_length)
  self.assertEqual(expected_num_inputs, len(expected_inputs))
  self._ValidateProvideBatch([e.SerializeToString() for e in examples],
                             truncated_length, batch_size, expected_inputs)
def validateProvideBatch_TFRecord(self, truncated_length, batch_size, lengths,
                                  expected_num_inputs):
  hparams = copy.deepcopy(constants.DEFAULT_HPARAMS)
  examples = []
  expected_inputs = []
  for i, length in enumerate(lengths):
    wav_samples = np.zeros(
        (int((length / data.hparams_frames_per_second(hparams)) *
             constants.DEFAULT_SAMPLE_RATE), 1), np.float32)
    wav_data = audio_io.samples_to_wav_data(wav_samples,
                                            constants.DEFAULT_SAMPLE_RATE)
    num_frames = data.wav_to_num_frames(
        wav_data, frames_per_second=data.hparams_frames_per_second(hparams))
    seq = self._SyntheticSequence(
        num_frames / data.hparams_frames_per_second(hparams),
        i + constants.MIN_MIDI_PITCH)
    examples.append(self._FillExample(seq, wav_data, 'ex%d' % i))
    expected_inputs += self._ExampleToInputs(examples[-1], truncated_length)
  self.assertEqual(expected_num_inputs, len(expected_inputs))
  with tempfile.NamedTemporaryFile() as temp_rio:
    with tf.python_io.TFRecordWriter(temp_rio.name) as writer:
      for ex in examples:
        writer.write(ex.SerializeToString())
    self.validateProvideBatch(temp_rio.name, truncated_length, batch_size,
                              expected_inputs)
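# Hedged standalone sketch (assumption, not from the original source) of the
# TFRecord round trip the test above relies on, using the same TF1-era
# tf.python_io API as the surrounding code.
import tempfile

import tensorflow as tf

with tempfile.NamedTemporaryFile() as temp_rio:
  with tf.python_io.TFRecordWriter(temp_rio.name) as writer:
    writer.write(b'serialized example bytes')
  # Reading the file back yields the raw serialized record bytes.
  records = list(tf.python_io.tf_record_iterator(temp_rio.name))
  assert records == [b'serialized example bytes']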
def _note_metrics(labels, predictions):
  """A pyfunc that wraps a call to precision_recall_f1_overlap."""
  est_sequence = sequences_lib.pianoroll_to_note_sequence(
      predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=hparams.min_duration_ms)
  ref_sequence = sequences_lib.pianoroll_to_note_sequence(
      labels,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=hparams.min_duration_ms)

  est_intervals, est_pitches, _ = infer_util.sequence_to_valued_intervals(
      est_sequence, hparams.min_duration_ms)
  ref_intervals, ref_pitches, _ = infer_util.sequence_to_valued_intervals(
      ref_sequence, hparams.min_duration_ms)

  if est_intervals.size == 0 or ref_intervals.size == 0:
    return 0., 0., 0.
  note_precision, note_recall, note_f1, _ = precision_recall_f1_overlap(
      ref_intervals, pretty_midi.note_number_to_hz(ref_pitches),
      est_intervals, pretty_midi.note_number_to_hz(est_pitches),
      offset_ratio=offset_ratio)

  return note_precision, note_recall, note_f1
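# Hedged sketch (assumption, not shown in the original source): a closure
# like _note_metrics is typically spliced into a TF1 graph with tf.py_func.
# Note that it reads `hparams` and `offset_ratio` from its enclosing scope,
# and its 0., 0., 0. early return makes all three outputs float64:
#
#   note_precision, note_recall, note_f1 = tf.py_func(
#       _note_metrics,
#       inp=[labels, predictions],
#       Tout=[tf.float64, tf.float64, tf.float64],
#       name='note_metrics')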
def predict_sequence(frame_probs,
                     onset_probs,
                     frame_predictions,
                     onset_predictions,
                     offset_predictions,
                     velocity_values,
                     min_pitch,
                     hparams,
                     onsets_only=False):
  """Predict sequence given model output."""
  if not hparams.predict_onset_threshold:
    onset_predictions = None
  if not hparams.predict_offset_threshold:
    offset_predictions = None

  if onsets_only:
    if onset_predictions is None:
      raise ValueError(
          'Cannot do onset only prediction if onsets are not defined.')
    sequence_prediction = sequences_lib.pianoroll_onsets_to_note_sequence(
        onsets=onset_predictions,
        frames_per_second=data.hparams_frames_per_second(hparams),
        note_duration_seconds=0.05,
        min_midi_pitch=min_pitch,
        velocity_values=velocity_values,
        velocity_scale=hparams.velocity_scale,
        velocity_bias=hparams.velocity_bias)
  else:
    if hparams.viterbi_decoding:
      pianoroll = probs_to_pianoroll_viterbi(
          frame_probs, onset_probs, alpha=hparams.viterbi_alpha)
      onsets = np.concatenate(
          [pianoroll[:1, :], pianoroll[1:, :] & ~pianoroll[:-1, :]], axis=0)
      sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
          frames=pianoroll,
          frames_per_second=data.hparams_frames_per_second(hparams),
          min_duration_ms=0,
          min_midi_pitch=min_pitch,
          onset_predictions=onsets,
          velocity_values=velocity_values,
          velocity_scale=hparams.velocity_scale,
          velocity_bias=hparams.velocity_bias)
    else:
      sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
          frames=frame_predictions,
          frames_per_second=data.hparams_frames_per_second(hparams),
          min_duration_ms=0,
          min_midi_pitch=min_pitch,
          onset_predictions=onset_predictions,
          offset_predictions=offset_predictions,
          velocity_values=velocity_values,
          velocity_scale=hparams.velocity_scale,
          velocity_bias=hparams.velocity_bias)

  return sequence_prediction
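# Hedged usage sketch for predict_sequence above. The arrays are random
# stand-ins with hypothetical shapes and a hypothetical 0.5 threshold; a real
# caller would take them from the model output, and `hparams` must provide
# the fields predict_sequence reads (predict_onset_threshold,
# predict_offset_threshold, viterbi_decoding, velocity_scale, velocity_bias).
import numpy as np

num_frames, num_pitches = 100, 88
frame_probs = np.random.uniform(size=(num_frames, num_pitches))
onset_probs = np.random.uniform(size=(num_frames, num_pitches))
velocity_values = np.random.uniform(size=(num_frames, num_pitches))

sequence = predict_sequence(
    frame_probs=frame_probs,
    onset_probs=onset_probs,
    frame_predictions=frame_probs > 0.5,
    onset_predictions=onset_probs > 0.5,
    offset_predictions=None,
    velocity_values=velocity_values,
    min_pitch=21,  # constants.MIN_MIDI_PITCH for piano.
    hparams=hparams)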
def transcribe_audio(transcription_session, filename, frame_threshold,
                     onset_threshold):
  """Transcribes an audio file."""
  tf.logging.info('Processing file...')
  transcription_session.session.run(
      transcription_session.iterator.initializer, {
          transcription_session.examples:
              [create_example(filename, transcription_session.hparams)]
      })

  tf.logging.info('Running inference...')
  # Despite the "logits" names, these fetched tensors hold probabilities
  # (they come from the *_probs_flat tensors).
  frame_logits, onset_logits, velocity_values = (
      transcription_session.session.run([
          transcription_session.frame_probs_flat,
          transcription_session.onset_probs_flat,
          transcription_session.velocity_values_flat
      ]))

  frame_predictions = frame_logits > frame_threshold
  onset_predictions = onset_logits > onset_threshold

  sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
      frame_predictions,
      frames_per_second=data.hparams_frames_per_second(
          transcription_session.hparams),
      min_duration_ms=0,
      onset_predictions=onset_predictions,
      velocity_values=velocity_values)

  for note in sequence_prediction.notes:
    note.pitch += constants.MIN_MIDI_PITCH

  return sequence_prediction
def _ValidateProvideBatchTFRecord(self,
                                  truncated_length,
                                  batch_size,
                                  lengths,
                                  expected_num_inputs,
                                  crop_sequence_secs=0):
  hparams = copy.deepcopy(constants.DEFAULT_HPARAMS)
  examples = []
  expected_inputs = []
  for i, length in enumerate(lengths):
    wav_samples = np.zeros(
        (int((length / data.hparams_frames_per_second(hparams)) *
             hparams.sample_rate), 1), np.float32)
    wav_data = audio_io.samples_to_wav_data(wav_samples, hparams.sample_rate)
    num_frames = data.wav_to_num_frames(
        wav_data, frames_per_second=data.hparams_frames_per_second(hparams))
    seq = self._SyntheticSequence(
        num_frames / data.hparams_frames_per_second(hparams) -
        crop_sequence_secs * 2,  # Crop from both ends.
        i + constants.MIN_MIDI_PITCH,
        start_time=crop_sequence_secs)
    examples.append(self._FillExample(seq, wav_data, 'ex%d' % i))
    expected_inputs += self._ExampleToInputs(
        examples[-1],
        truncated_length,
        crop_training_sequence_to_notes=crop_sequence_secs > 0)
  self.assertEqual(expected_num_inputs, len(expected_inputs))
  with tempfile.NamedTemporaryFile() as temp_tfr:
    with tf.python_io.TFRecordWriter(temp_tfr.name) as writer:
      for ex in examples:
        writer.write(ex.SerializeToString())
    self._ValidateProvideBatch(
        temp_tfr.name,
        truncated_length,
        batch_size,
        expected_inputs,
        crop_training_sequence_to_notes=crop_sequence_secs > 0)
def infer(filename):
  # Read the WAV file as binary.
  wav = open(filename, 'rb')
  wav_data = wav.read()
  wav.close()
  tf.logging.info('User .wav file %s length %s bytes', filename,
                  len(wav_data))

  ## Preprocessing.
  # Split into chunks and convert to protocol buffers.
  to_process = []
  examples = list(
      audio_label_data_utils.process_record(
          wav_data=wav_data,
          ns=music_pb2.NoteSequence(),
          example_id=filename,
          min_length=0,
          max_length=-1,
          allow_empty_notesequence=True))

  # Serialize the chunked buffers.
  to_process.append(examples[0].SerializeToString())

  # Feed the serialized buffers into the iterator.
  sess.run(iterator.initializer, {example: to_process})

  # Inference.
  predictions = list(estimator.predict(input_fn, yield_single_examples=False))

  # The assertion guarantees there is exactly one prediction.
  assert len(predictions) == 1

  # Fetch the prediction results.
  frame_predictions = predictions[0]['frame_predictions'][0]
  onset_predictions = predictions[0]['onset_predictions'][0]  # Note onsets.
  velocity_values = predictions[0]['velocity_values'][0]  # Note velocities.

  # Encode as MIDI.
  sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
      frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=constants.MIN_MIDI_PITCH,
      onset_predictions=onset_predictions,
      velocity_values=velocity_values)

  basename = os.path.split(os.path.splitext(filename)[0])[1] + '.mid'
  output_filename = os.path.join('', basename)
  midi_filename = output_filename
  midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)
  print('Program ended; your MIDI file is in', output_filename)
  sess.close()
def _ValidateProvideBatch(self,
                          examples,
                          truncated_length,
                          batch_size,
                          expected_inputs,
                          feed_dict=None):
  """Tests for correctness of batches."""
  hparams = copy.deepcopy(configs.DEFAULT_HPARAMS)
  hparams.batch_size = batch_size
  hparams.truncated_length_secs = (
      truncated_length / data.hparams_frames_per_second(hparams))

  with self.test_session() as sess:
    dataset = data.provide_batch(
        examples=examples,
        preprocess_examples=True,
        params=hparams,
        is_training=False,
        shuffle_examples=False,
        skip_n_initial_records=0)
    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    sess.run([
        tf.initializers.local_variables(),
        tf.initializers.global_variables(),
        iterator.initializer
    ], feed_dict=feed_dict)
    for i in range(0, len(expected_inputs), batch_size):
      # Wait to ensure example is pre-processed.
      time.sleep(0.1)
      features, labels = sess.run(next_record)
      inputs = [
          features.spec, labels.labels, features.length, features.sequence_id
      ]
      max_length = np.max(inputs[2])
      for j in range(batch_size):
        # Add batch padding if needed.
        input_length = expected_inputs[i + j][2]
        if input_length < max_length:
          expected_inputs[i + j] = list(expected_inputs[i + j])
          pad_amt = max_length - input_length
          expected_inputs[i + j][0] = np.pad(
              expected_inputs[i + j][0], [(0, pad_amt), (0, 0)], 'constant')
          expected_inputs[i + j][1] = np.pad(
              expected_inputs[i + j][1], [(0, pad_amt), (0, 0)], 'constant')
        for exp_input, input_ in zip(expected_inputs[i + j], inputs):
          self.assertAllEqual(np.squeeze(exp_input), np.squeeze(input_[j]))

    with self.assertRaisesOpError('End of sequence'):
      _ = sess.run(next_record)
def transcribe_audio(prediction, hparams):
  """Transcribes an audio file."""
  frame_predictions = prediction['frame_predictions']
  onset_predictions = prediction['onset_predictions']
  velocity_values = prediction['velocity_values']

  sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
      frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=constants.MIN_MIDI_PITCH,
      onset_predictions=onset_predictions,
      velocity_values=velocity_values)

  return sequence_prediction
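# Hedged usage sketch (assumption, not from the original source): running
# per-example predictions through transcribe_audio above and writing each
# result out. Assumes `estimator`, `input_fn`, `hparams` and `midi_io` are
# set up as in the surrounding snippets; yield_single_examples=True makes
# the Estimator strip the batch dimension, matching how transcribe_audio
# indexes the prediction dict directly.
for i, prediction in enumerate(
    estimator.predict(input_fn, yield_single_examples=True)):
  sequence = transcribe_audio(prediction, hparams)
  midi_io.sequence_proto_to_midi_file(sequence, 'transcribed_%d.mid' % i)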
def _ValidateProvideBatch(self,
                          examples,
                          truncated_length,
                          batch_size,
                          expected_inputs,
                          feed_dict=None):
  """Tests for correctness of batches."""
  hparams = copy.deepcopy(configs.DEFAULT_HPARAMS)
  hparams.batch_size = batch_size
  hparams.truncated_length_secs = (
      truncated_length / data.hparams_frames_per_second(hparams))

  with self.test_session() as sess:
    dataset = data.provide_batch(
        examples=examples,
        preprocess_examples=True,
        hparams=hparams,
        is_training=False)
    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    sess.run([
        tf.initializers.local_variables(),
        tf.initializers.global_variables(),
        iterator.initializer
    ], feed_dict=feed_dict)
    for i in range(0, len(expected_inputs), batch_size):
      # Wait to ensure example is pre-processed.
      time.sleep(0.1)
      features, labels = sess.run(next_record)
      inputs = [
          features.spec, labels.labels, features.length, features.sequence_id
      ]
      max_length = np.max(inputs[2])
      for j in range(batch_size):
        # Add batch padding if needed.
        input_length = expected_inputs[i + j][2]
        if input_length < max_length:
          expected_inputs[i + j] = list(expected_inputs[i + j])
          pad_amt = max_length - input_length
          expected_inputs[i + j][0] = np.pad(
              expected_inputs[i + j][0], [(0, pad_amt), (0, 0)], 'constant')
          expected_inputs[i + j][1] = np.pad(
              expected_inputs[i + j][1], [(0, pad_amt), (0, 0)], 'constant')
        for exp_input, input_ in zip(expected_inputs[i + j], inputs):
          self.assertAllEqual(np.squeeze(exp_input), np.squeeze(input_[j]))

    with self.assertRaisesOpError('End of sequence'):
      _ = sess.run(next_record)
def inference(filename):
  # Read the audio (.wav) file.
  wav_file = open(filename, mode='rb')
  wav_data = wav_file.read()
  wav_file.close()
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=filename, length=len(wav_data)))

  # Split into chunks and build protobuf examples.
  to_process = []
  example_list = list(
      audio_label_data_utils.process_record(
          wav_data=wav_data,
          ns=music_pb2.NoteSequence(),
          example_id=filename,
          min_length=0,
          max_length=-1,
          allow_empty_notesequence=True))

  # Serialize.
  to_process.append(example_list[0].SerializeToString())

  # Run the session.
  sess.run(iterator.initializer, {examples: to_process})

  # Predict.
  prediction_list = list(
      estimator.predict(input_fn, yield_single_examples=False))
  assert len(prediction_list) == 1

  # Fetch the prediction results.
  frame_predictions = prediction_list[0]['frame_predictions'][0]
  onset_predictions = prediction_list[0]['onset_predictions'][0]
  velocity_values = prediction_list[0]['velocity_values'][0]

  # Build the MIDI sequence from the prediction results.
  sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
      frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=constants.MIN_MIDI_PITCH,
      onset_predictions=onset_predictions,
      velocity_values=velocity_values)

  basename = os.path.split(os.path.splitext(filename)[0])[1] + '.mid'
  output_filename = os.path.join(env.MIDI_DIRECTORY, basename)

  # Export the MIDI sequence to a file.
  midi_filename = output_filename
  midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

  return basename
def predict_sequence(frame_predictions, onset_predictions, offset_predictions,
                     velocity_values, min_pitch, hparams):
  """Predict sequence given model output."""
  if not hparams.predict_onset_threshold:
    onset_predictions = None
  if not hparams.predict_offset_threshold:
    offset_predictions = None

  sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
      frames=frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=min_pitch,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values)

  return sequence_prediction
def model_inference(acoustic_checkpoint, hparams, examples_path, run_dir):
  """Runs inference for the given examples."""
  tf.logging.info('acoustic_checkpoint=%s', acoustic_checkpoint)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('run_dir=%s', run_dir)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    # Build the acoustic model within an 'acoustic' scope to isolate its
    # variables from the other models.
    with tf.variable_scope('acoustic'):
      truncated_length = 0
      if FLAGS.max_seconds_per_sequence:
        truncated_length = int(
            math.ceil((FLAGS.max_seconds_per_sequence *
                       data.hparams_frames_per_second(hparams))))
      acoustic_data_provider, _ = data.provide_batch(
          batch_size=1,
          examples=examples_path,
          hparams=hparams,
          is_training=False,
          truncated_length=truncated_length,
          include_note_sequences=True)

      _, _, data_labels, _, _ = model.get_model(
          acoustic_data_provider, hparams, is_training=False)

    # The checkpoints won't have the new scopes.
    acoustic_variables = {
        re.sub(r'^acoustic/', '', var.op.name): var
        for var in slim.get_variables(scope='acoustic/')
    }
    acoustic_restore = tf.train.Saver(acoustic_variables)

    onset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/onsets/onset_probs_flat:0')
    frame_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/frame_probs_flat:0')
    offset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/offsets/offset_probs_flat:0')
    velocity_values_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/velocity/velocity_values_flat:0')

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()
    global_step = tf.contrib.framework.get_or_create_global_step()
    global_step_increment = global_step.assign_add(1)

    # Use a custom init function to restore the acoustic and language models
    # from their separate checkpoints.
    def init_fn(unused_self, sess):
      acoustic_restore.restore(sess, acoustic_checkpoint)

    scaffold = tf.train.Scaffold(init_fn=init_fn)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, master=FLAGS.master)

    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
      tf.logging.info('running session')
      summary_writer = tf.summary.FileWriter(logdir=run_dir, graph=sess.graph)

      tf.logging.info('Inferring for %d batches',
                      acoustic_data_provider.num_batches)
      infer_times = []
      num_frames = []
      for unused_i in range(acoustic_data_provider.num_batches):
        start_time = time.time()
        (labels, filenames, note_sequences, frame_probs, onset_probs,
         offset_probs, velocity_values) = sess.run([
             data_labels,
             acoustic_data_provider.filenames,
             acoustic_data_provider.note_sequences,
             frame_probs_flat,
             onset_probs_flat,
             offset_probs_flat,
             velocity_values_flat,
         ])
        # We expect these all to be length 1 because batch size is 1.
        assert len(filenames) == len(note_sequences) == 1
        # These should be the same length and have been flattened.
        assert len(labels) == len(frame_probs) == len(onset_probs)

        frame_predictions = frame_probs > FLAGS.frame_threshold
        if FLAGS.require_onset:
          onset_predictions = onset_probs > FLAGS.onset_threshold
        else:
          onset_predictions = None
        if FLAGS.use_offset:
          offset_predictions = offset_probs > FLAGS.offset_threshold
        else:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_probs.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_probs.shape[0],
            frame_probs.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filenames[0])

        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        sequence_label = infer_util.score_sequence(
            sess, global_step_increment, summary_op, summary_writer,
            metrics_to_updates, metric_note_precision, metric_note_recall,
            metric_note_f1, metric_note_precision_with_offsets,
            metric_note_recall_with_offsets, metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity, metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequences_lib.adjust_notesequence_times(
                music_pb2.NoteSequence.FromString(note_sequences[0]),
                shift_notesequence)[0],
            sequence_id=filenames[0])

        # Make filenames UNIX-friendly.
        filename = filenames[0].decode('utf-8').replace('/', '_').replace(
            ':', '.')
        output_file = os.path.join(run_dir, filename + '.mid')
        tf.logging.info('Writing inferred midi file to %s', output_file)
        midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

        label_output_file = os.path.join(run_dir, filename + '_label.mid')
        tf.logging.info('Writing label midi file to %s', label_output_file)
        midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

        # Also write a pianoroll showing acoustic model output vs labels.
        pianoroll_output_file = os.path.join(run_dir,
                                             filename + '_pianoroll.png')
        tf.logging.info('Writing acoustic logit/label file to %s',
                        pianoroll_output_file)
        with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
          scipy.misc.imsave(
              f,
              infer_util.posterior_pianoroll_image(
                  frame_probs,
                  sequence_prediction,
                  labels,
                  overlap=True,
                  frames_per_second=data.hparams_frames_per_second(hparams)))

        summary_writer.flush()
def _calculate_metrics_py(frame_predictions, onset_predictions,
                          offset_predictions, velocity_values,
                          sequence_label_str, frame_labels, sequence_id,
                          hparams):
  """Python logic for calculating metrics on a single example."""
  tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                  frame_labels.shape[0])
  if not hparams.predict_onset_threshold:
    onset_predictions = None
  if not hparams.predict_offset_threshold:
    offset_predictions = None

  sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
      frames=frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=constants.MIN_MIDI_PITCH,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values)

  sequence_label = music_pb2.NoteSequence.FromString(sequence_label_str)

  if hparams.backward_shift_amount_ms:

    def shift_notesequence(ns_time):
      return ns_time + hparams.backward_shift_amount_ms / 1000.

    shifted_sequence_label, skipped_notes = (
        sequences_lib.adjust_notesequence_times(sequence_label,
                                                shift_notesequence))
    assert skipped_notes == 0
    sequence_label = shifted_sequence_label

  est_intervals, est_pitches, est_velocities = (
      infer_util.sequence_to_valued_intervals(sequence_prediction))

  ref_intervals, ref_pitches, ref_velocities = (
      infer_util.sequence_to_valued_intervals(sequence_label))

  note_precision, note_recall, note_f1, _ = (
      mir_eval.transcription.precision_recall_f1_overlap(
          ref_intervals,
          pretty_midi.note_number_to_hz(ref_pitches),
          est_intervals,
          pretty_midi.note_number_to_hz(est_pitches),
          offset_ratio=None))

  (note_with_offsets_precision, note_with_offsets_recall,
   note_with_offsets_f1, _) = (
       mir_eval.transcription.precision_recall_f1_overlap(
           ref_intervals, pretty_midi.note_number_to_hz(ref_pitches),
           est_intervals, pretty_midi.note_number_to_hz(est_pitches)))

  (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
   note_with_offsets_velocity_f1, _) = (
       mir_eval.transcription_velocity.precision_recall_f1_overlap(
           ref_intervals=ref_intervals,
           ref_pitches=pretty_midi.note_number_to_hz(ref_pitches),
           ref_velocities=ref_velocities,
           est_intervals=est_intervals,
           est_pitches=pretty_midi.note_number_to_hz(est_pitches),
           est_velocities=est_velocities))

  processed_frame_predictions = sequences_lib.sequence_to_pianoroll(
      sequence_prediction,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_pitch=constants.MIN_MIDI_PITCH,
      max_pitch=constants.MAX_MIDI_PITCH).active

  if processed_frame_predictions.shape[0] < frame_labels.shape[0]:
    # Pad transcribed frames with silence.
    pad_length = frame_labels.shape[0] - processed_frame_predictions.shape[0]
    processed_frame_predictions = np.pad(processed_frame_predictions,
                                         [(0, pad_length), (0, 0)],
                                         'constant')
  elif processed_frame_predictions.shape[0] > frame_labels.shape[0]:
    # Truncate transcribed frames.
    processed_frame_predictions = (
        processed_frame_predictions[:frame_labels.shape[0], :])

  tf.logging.info(
      'Metrics for %s: Note F1 %f, Note w/ offsets F1 %f, '
      'Note w/ offsets & velocity: %f', sequence_id, note_f1,
      note_with_offsets_f1, note_with_offsets_velocity_f1)
  return (note_precision, note_recall, note_f1, note_with_offsets_precision,
          note_with_offsets_recall, note_with_offsets_f1,
          note_with_offsets_velocity_precision,
          note_with_offsets_velocity_recall, note_with_offsets_velocity_f1,
          processed_frame_predictions)
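# Hedged standalone sketch of the mir_eval call at the heart of the metrics
# above, on two hand-made notes (all values illustrative). offset_ratio=None
# scores onsets and pitches only; leaving it at the default additionally
# requires matching offsets, which is what the "with offsets" metrics use.
import numpy as np
import mir_eval
import pretty_midi

ref_intervals = np.array([[0.0, 1.0], [1.0, 2.0]])
ref_pitches = pretty_midi.note_number_to_hz(np.array([60, 64]))
est_intervals = np.array([[0.02, 1.01], [1.0, 1.4]])
est_pitches = pretty_midi.note_number_to_hz(np.array([60, 64]))

precision, recall, f1, _ = mir_eval.transcription.precision_recall_f1_overlap(
    ref_intervals, ref_pitches, est_intervals, est_pitches, offset_ratio=None)
print(precision, recall, f1)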
def model_inference(model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    write_summary_every_step=True):
  """Runs inference for the given examples."""
  tf.logging.info('model_dir=%s', model_dir)
  tf.logging.info('checkpoint_path=%s', checkpoint_path)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('output_dir=%s', output_dir)

  estimator = train_util.create_estimator(model_dir, hparams)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    if FLAGS.max_seconds_per_sequence:
      truncated_length = int(
          math.ceil((FLAGS.max_seconds_per_sequence *
                     data.hparams_frames_per_second(hparams))))
    else:
      truncated_length = 0

    dataset = data.provide_batch(
        batch_size=1,
        examples=examples_path,
        hparams=hparams,
        is_training=False,
        truncated_length=truncated_length)

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()

    if write_summary_every_step:
      global_step = tf.train.get_or_create_global_step()
      global_step_increment = global_step.assign_add(1)
    else:
      global_step = tf.constant(
          estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
      global_step_increment = global_step

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      infer_times = []
      num_frames = []

      sess.run(iterator.initializer)
      while True:
        try:
          record = sess.run(next_record)
        except tf.errors.OutOfRangeError:
          break

        def input_fn():
          return tf.data.Dataset.from_tensors(record)

        start_time = time.time()

        # TODO(fjord): This is a hack that allows us to keep using our
        # existing infer/scoring code with a tf.Estimator model. Ideally, we
        # should move things around so that we can use estimator.evaluate,
        # which will also be more efficient because it won't have to restore
        # the checkpoint for every example.
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1

        input_features = record[0]
        input_labels = record[1]

        filename = input_features.sequence_id[0]
        note_sequence = music_pb2.NoteSequence.FromString(
            input_labels.note_sequence[0])
        labels = input_labels.labels[0]

        frame_probs = prediction_list[0]['frame_probs_flat']
        onset_probs = prediction_list[0]['onset_probs_flat']
        velocity_values = prediction_list[0]['velocity_values_flat']
        offset_probs = prediction_list[0]['offset_probs_flat']

        frame_predictions = frame_probs > FLAGS.frame_threshold
        if FLAGS.require_onset:
          onset_predictions = onset_probs > FLAGS.onset_threshold
        else:
          onset_predictions = None
        if FLAGS.use_offset:
          offset_predictions = offset_probs > FLAGS.offset_threshold
        else:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_probs.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_probs.shape[0],
            frame_probs.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filename)

        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        sequence_label = sequences_lib.adjust_notesequence_times(
            note_sequence, shift_notesequence)[0]
        infer_util.score_sequence(
            sess,
            global_step_increment,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequence_label,
            sequence_id=filename)

        if write_summary_every_step:
          # Make filenames UNIX-friendly.
          filename_safe = filename.decode('utf-8').replace('/', '_').replace(
              ':', '.')
          output_file = os.path.join(output_dir, filename_safe + '.mid')
          tf.logging.info('Writing inferred midi file to %s', output_file)
          midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                              output_file)

          label_output_file = os.path.join(output_dir,
                                           filename_safe + '_label.mid')
          tf.logging.info('Writing label midi file to %s', label_output_file)
          midi_io.sequence_proto_to_midi_file(sequence_label,
                                              label_output_file)

          # Also write a pianoroll showing acoustic model output vs labels.
          pianoroll_output_file = os.path.join(
              output_dir, filename_safe + '_pianoroll.png')
          tf.logging.info('Writing acoustic logit/label file to %s',
                          pianoroll_output_file)
          with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    frame_probs,
                    sequence_prediction,
                    labels,
                    overlap=True,
                    frames_per_second=data.hparams_frames_per_second(
                        hparams)))

          summary = sess.run(summary_op)
          summary_writer.add_summary(summary, sess.run(global_step))
          summary_writer.flush()

      if not write_summary_every_step:
        # Only write the summary variables for the final step.
        summary = sess.run(summary_op)
        summary_writer.add_summary(summary, sess.run(global_step))
        summary_writer.flush()
def model_inference(model_fn, model_dir, checkpoint_path, data_fn, hparams,
                    examples_path, output_dir, summary_writer, master,
                    preprocess_examples, shuffle_examples):
  """Runs inference for the given examples."""
  tf.logging.info('model_dir=%s', model_dir)
  tf.logging.info('checkpoint_path=%s', checkpoint_path)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('output_dir=%s', output_dir)

  estimator = train_util.create_estimator(
      model_fn, model_dir, hparams, master=master)

  transcription_data = functools.partial(
      data_fn,
      examples=examples_path,
      preprocess_examples=preprocess_examples,
      is_training=False,
      shuffle_examples=shuffle_examples,
      skip_n_initial_records=0)

  input_fn = infer_util.labels_to_features_wrapper(transcription_data)

  start_time = time.time()
  infer_times = []
  num_frames = []

  file_num = 0

  all_metrics = collections.defaultdict(list)

  for predictions in estimator.predict(
      input_fn, checkpoint_path=checkpoint_path, yield_single_examples=False):

    # Remove batch dimension for convenience.
    for k in predictions.keys():
      if predictions[k].shape[0] != 1:
        raise ValueError(
            'All predictions must have batch size 1, but shape of '
            '{} was: {}'.format(k, predictions[k].shape[0]))
      predictions[k] = predictions[k][0]

    end_time = time.time()
    infer_time = end_time - start_time
    infer_times.append(infer_time)
    num_frames.append(predictions['frame_predictions'].shape[0])
    tf.logging.info(
        'Infer time %f, frames %d, frames/sec %f, running average %f',
        infer_time, num_frames[-1], num_frames[-1] / infer_time,
        np.sum(num_frames) / np.sum(infer_times))

    tf.logging.info('Scoring sequence %s', predictions['sequence_ids'])

    sequence_prediction = music_pb2.NoteSequence.FromString(
        predictions['sequence_predictions'])
    sequence_label = music_pb2.NoteSequence.FromString(
        predictions['sequence_labels'])

    # Make filenames UNIX-friendly.
    filename_chars = six.ensure_text(predictions['sequence_ids'], 'utf-8')
    filename_chars = [c if c.isalnum() else '_' for c in filename_chars]
    filename_safe = ''.join(filename_chars).rstrip()
    filename_safe = '{:04d}_{}'.format(file_num, filename_safe[:200])
    file_num += 1
    output_file = os.path.join(output_dir, filename_safe + '.mid')
    tf.logging.info('Writing inferred midi file to %s', output_file)
    midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

    label_output_file = os.path.join(output_dir, filename_safe + '_label.mid')
    tf.logging.info('Writing label midi file to %s', label_output_file)
    midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

    # Also write a pianoroll showing acoustic model output vs labels.
    pianoroll_output_file = os.path.join(output_dir,
                                         filename_safe + '_pianoroll.png')
    tf.logging.info('Writing acoustic logit/label file to %s',
                    pianoroll_output_file)
    # Calculate frames based on the sequence. Includes any postprocessing
    # done to turn raw onsets/frames predictions into the final sequence.
    # TODO(fjord): This work is duplicated in metrics.py.
    sequence_frame_predictions = sequences_lib.sequence_to_pianoroll(
        sequence_prediction,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_pitch=constants.MIN_MIDI_PITCH,
        max_pitch=constants.MAX_MIDI_PITCH).active
    with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
      imageio.imwrite(
          f,
          infer_util.posterior_pianoroll_image(
              predictions['onset_probs'],
              predictions['onset_labels'],
              predictions['frame_probs'],
              predictions['frame_labels'],
              sequence_frame_predictions),
          format='png')

    # Update histogram and current scalar for metrics.
    with tf.Graph().as_default(), tf.Session().as_default():
      for k, v in predictions.items():
        if not k.startswith('metrics/'):
          continue
        all_metrics[k].extend(v)
        histogram_name = k + '_histogram'
        metric_summary = tf.summary.histogram(histogram_name, all_metrics[k])
        summary_writer.add_summary(metric_summary.eval(),
                                   global_step=file_num)
        scalar_name = k
        metric_summary = tf.summary.scalar(scalar_name,
                                           np.mean(all_metrics[k]))
        summary_writer.add_summary(metric_summary.eval(),
                                   global_step=file_num)
      summary_writer.flush()

    start_time = time.time()

  # Write final mean values for all metrics.
  with tf.Graph().as_default(), tf.Session().as_default():
    for k, v in all_metrics.items():
      final_scalar_name = 'final/' + k
      metric_summary = tf.summary.scalar(final_scalar_name,
                                         np.mean(all_metrics[k]))
      summary_writer.add_summary(metric_summary.eval())
    summary_writer.flush()
def model_inference(model_fn,
                    model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    master,
                    preprocess_examples,
                    write_summary_every_step=True):
  """Runs inference for the given examples."""
  tf.logging.info('model_dir=%s', model_dir)
  tf.logging.info('checkpoint_path=%s', checkpoint_path)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('output_dir=%s', output_dir)

  estimator = train_util.create_estimator(
      model_fn, model_dir, hparams, master=master)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    dataset = data.provide_batch(
        examples=examples_path,
        preprocess_examples=preprocess_examples,
        hparams=hparams,
        is_training=False)

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()

    if write_summary_every_step:
      global_step = tf.train.get_or_create_global_step()
      global_step_increment = global_step.assign_add(1)
    else:
      global_step = tf.constant(
          estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
      global_step_increment = global_step

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      infer_times = []
      num_frames = []

      sess.run(iterator.initializer)
      while True:
        try:
          record = sess.run(next_record)
        except tf.errors.OutOfRangeError:
          break

        def input_fn(params):
          del params
          return tf.data.Dataset.from_tensors(record)

        start_time = time.time()

        # TODO(fjord): This is a hack that allows us to keep using our
        # existing infer/scoring code with a tf.Estimator model. Ideally, we
        # should move things around so that we can use estimator.evaluate,
        # which will also be more efficient because it won't have to restore
        # the checkpoint for every example.
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1

        input_features = record[0]
        input_labels = record[1]

        filename = input_features.sequence_id[0]
        note_sequence = music_pb2.NoteSequence.FromString(
            input_labels.note_sequence[0])
        labels = input_labels.labels[0]

        frame_probs = prediction_list[0]['frame_probs'][0]
        frame_predictions = prediction_list[0]['frame_predictions'][0]
        onset_predictions = prediction_list[0]['onset_predictions'][0]
        velocity_values = prediction_list[0]['velocity_values'][0]
        offset_predictions = prediction_list[0]['offset_predictions'][0]

        if not FLAGS.require_onset:
          onset_predictions = None

        if not FLAGS.use_offset:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_predictions.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_predictions.shape[0],
            frame_predictions.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filename)

        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        sequence_label = sequences_lib.adjust_notesequence_times(
            note_sequence, shift_notesequence)[0]
        infer_util.score_sequence(
            sess,
            global_step_increment,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequence_label,
            sequence_id=filename)

        if write_summary_every_step:
          # Make filenames UNIX-friendly.
          filename_safe = filename.decode('utf-8').replace('/', '_').replace(
              ':', '.')
          output_file = os.path.join(output_dir, filename_safe + '.mid')
          tf.logging.info('Writing inferred midi file to %s', output_file)
          midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                              output_file)

          label_output_file = os.path.join(output_dir,
                                           filename_safe + '_label.mid')
          tf.logging.info('Writing label midi file to %s', label_output_file)
          midi_io.sequence_proto_to_midi_file(sequence_label,
                                              label_output_file)

          # Also write a pianoroll showing acoustic model output vs labels.
          pianoroll_output_file = os.path.join(
              output_dir, filename_safe + '_pianoroll.png')
          tf.logging.info('Writing acoustic logit/label file to %s',
                          pianoroll_output_file)
          with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    frame_probs,
                    sequence_prediction,
                    labels,
                    overlap=True,
                    frames_per_second=data.hparams_frames_per_second(
                        hparams)))

          summary = sess.run(summary_op)
          summary_writer.add_summary(summary, sess.run(global_step))
          summary_writer.flush()

      if not write_summary_every_step:
        # Only write the summary variables for the final step.
        summary = sess.run(summary_op)
        summary_writer.add_summary(summary, sess.run(global_step))
        summary_writer.flush()
example = tf.train.Example(features=tf.train.Features(feature={
    'id':
        tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[fn.encode('utf-8')])),
    'sequence':
        tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[music_pb2.NoteSequence().SerializeToString()])),
    'audio':
        tf.train.Feature(bytes_list=tf.train.BytesList(value=[wav_data])),
}))
to_process.append(example.SerializeToString())
print('Processing complete for', fn)

session.run(iterator.initializer, {examples: to_process})

filenames, frame_logits, onset_logits = session.run(
    [batch.filenames, frame_probs_flat, onset_probs_flat])
print('Inference complete for', filenames[0])

frame_predictions = frame_logits > .5
onset_predictions = onset_logits > .5

sequence_prediction = infer_util.pianoroll_to_note_sequence(
    frame_predictions,
    frames_per_second=data.hparams_frames_per_second(hparams),
    min_duration_ms=0,
    onset_predictions=onset_predictions)

midi_filename = (filenames[0] + '.mid').replace(' ', '_')
# TODO: write the MIDI file into './output/' instead of the working directory.
midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)
def _calculate_metrics_py(frame_probs,
                          onset_probs,
                          frame_predictions,
                          onset_predictions,
                          offset_predictions,
                          velocity_values,
                          sequence_label_str,
                          frame_labels,
                          sequence_id,
                          hparams,
                          min_pitch,
                          max_pitch,
                          onsets_only,
                          restrict_to_pitch=None):
  """Python logic for calculating metrics on a single example."""
  tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                  frame_labels.shape[0])

  sequence_prediction = infer_util.predict_sequence(
      frame_probs=frame_probs,
      onset_probs=onset_probs,
      frame_predictions=frame_predictions,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values,
      min_pitch=min_pitch,
      hparams=hparams,
      onsets_only=onsets_only)

  sequence_label = music_pb2.NoteSequence.FromString(sequence_label_str)

  if hparams.backward_shift_amount_ms:

    def shift_notesequence(ns_time):
      return ns_time + hparams.backward_shift_amount_ms / 1000.

    shifted_sequence_label, skipped_notes = (
        sequences_lib.adjust_notesequence_times(sequence_label,
                                                shift_notesequence))
    assert skipped_notes == 0
    sequence_label = shifted_sequence_label

  est_intervals, est_pitches, est_velocities = (
      sequence_to_valued_intervals(
          sequence_prediction, restrict_to_pitch=restrict_to_pitch))

  ref_intervals, ref_pitches, ref_velocities = (
      sequence_to_valued_intervals(
          sequence_label, restrict_to_pitch=restrict_to_pitch))

  processed_frame_predictions = sequences_lib.sequence_to_pianoroll(
      sequence_prediction,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_pitch=min_pitch,
      max_pitch=max_pitch).active

  if processed_frame_predictions.shape[0] < frame_labels.shape[0]:
    # Pad transcribed frames with silence.
    pad_length = frame_labels.shape[0] - processed_frame_predictions.shape[0]
    processed_frame_predictions = np.pad(processed_frame_predictions,
                                         [(0, pad_length), (0, 0)],
                                         'constant')
  elif processed_frame_predictions.shape[0] > frame_labels.shape[0]:
    # Truncate transcribed frames.
    processed_frame_predictions = (
        processed_frame_predictions[:frame_labels.shape[0], :])

  if len(ref_pitches) == 0:
    tf.logging.info(
        'Reference pitches were length 0, returning empty metrics for %s:',
        sequence_id)
    return tuple([[]] * 12 + [processed_frame_predictions])

  note_precision, note_recall, note_f1, _ = (
      mir_eval.transcription.precision_recall_f1_overlap(
          ref_intervals,
          pretty_midi.note_number_to_hz(ref_pitches),
          est_intervals,
          pretty_midi.note_number_to_hz(est_pitches),
          offset_ratio=None))

  (note_with_velocity_precision, note_with_velocity_recall,
   note_with_velocity_f1, _) = (
       mir_eval.transcription_velocity.precision_recall_f1_overlap(
           ref_intervals=ref_intervals,
           ref_pitches=pretty_midi.note_number_to_hz(ref_pitches),
           ref_velocities=ref_velocities,
           est_intervals=est_intervals,
           est_pitches=pretty_midi.note_number_to_hz(est_pitches),
           est_velocities=est_velocities,
           offset_ratio=None))

  (note_with_offsets_precision, note_with_offsets_recall,
   note_with_offsets_f1, _) = (
       mir_eval.transcription.precision_recall_f1_overlap(
           ref_intervals, pretty_midi.note_number_to_hz(ref_pitches),
           est_intervals, pretty_midi.note_number_to_hz(est_pitches)))

  (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
   note_with_offsets_velocity_f1, _) = (
       mir_eval.transcription_velocity.precision_recall_f1_overlap(
           ref_intervals=ref_intervals,
           ref_pitches=pretty_midi.note_number_to_hz(ref_pitches),
           ref_velocities=ref_velocities,
           est_intervals=est_intervals,
           est_pitches=pretty_midi.note_number_to_hz(est_pitches),
           est_velocities=est_velocities))

  tf.logging.info(
      'Metrics for %s: Note F1 %f, Note w/ velocity F1 %f, '
      'Note w/ offsets F1 %f, Note w/ offsets & velocity: %f', sequence_id,
      note_f1, note_with_velocity_f1, note_with_offsets_f1,
      note_with_offsets_velocity_f1)
  # Return 1-d tensors for the metrics.
  return ([note_precision], [note_recall], [note_f1],
          [note_with_velocity_precision], [note_with_velocity_recall],
          [note_with_velocity_f1], [note_with_offsets_precision],
          [note_with_offsets_recall], [note_with_offsets_f1],
          [note_with_offsets_velocity_precision],
          [note_with_offsets_velocity_recall],
          [note_with_offsets_velocity_f1], [processed_frame_predictions])
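# Hedged standalone sketch of the pad-or-truncate alignment used above to
# make a transcribed pianoroll comparable with a label pianoroll of fixed
# length (the helper name is hypothetical).
import numpy as np

def align_frames(predicted, num_label_frames):
  """Pads `predicted` with silent frames, or truncates it, so that its first
  dimension equals num_label_frames."""
  if predicted.shape[0] < num_label_frames:
    pad_length = num_label_frames - predicted.shape[0]
    return np.pad(predicted, [(0, pad_length), (0, 0)], 'constant')
  return predicted[:num_label_frames, :]

assert align_frames(np.ones((3, 88)), 5).shape == (5, 88)
assert align_frames(np.ones((7, 88)), 5).shape == (5, 88)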