Example #1
def main(unused_argv):
    logging.set_verbosity(FLAGS.log)
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)
    for input_file in sorted(os.listdir(FLAGS.input_dir)):
        if not input_file.endswith('.wav'):
            continue
        wav_filename = input_file
        midi_filename = input_file.replace('.wav', '.mid')
        logging.info('Aligning %s to %s', midi_filename, wav_filename)

        samples = audio_io.load_audio(
            os.path.join(FLAGS.input_dir, wav_filename),
            align_fine_lib.SAMPLE_RATE)
        ns = midi_io.midi_file_to_sequence_proto(
            os.path.join(FLAGS.input_dir, midi_filename))

        aligned_ns, unused_stats = align_fine_lib.align_cpp(
            samples,
            align_fine_lib.SAMPLE_RATE,
            ns,
            align_fine_lib.CQT_HOP_LENGTH_FINE,
            sf2_path=FLAGS.sf2_path,
            penalty_mul=FLAGS.penalty_mul)

        midi_io.sequence_proto_to_midi_file(
            aligned_ns, os.path.join(FLAGS.output_dir, midi_filename))

    logging.info('Done')
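All of the examples on this page end the same way: a NoteSequence proto is handed to midi_io.sequence_proto_to_midi_file, which writes it out as a standard MIDI file. A minimal, self-contained sketch of that call (assuming the note_seq package, which provides midi_io and the music_pb2 protos) might look like this:

from note_seq import midi_io
from note_seq.protobuf import music_pb2

# Build a tiny one-note NoteSequence by hand.
sequence = music_pb2.NoteSequence()
sequence.tempos.add(qpm=120)
sequence.notes.add(
    pitch=60, velocity=80, start_time=0.0, end_time=0.5, instrument=0)
sequence.total_time = 0.5

# Write the sequence out as a .mid file.
midi_io.sequence_proto_to_midi_file(sequence, 'example.mid')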
Example #2
    def save_midi(self, y_probs, y_true, epoch):
        frame_predictions = y_probs[0][0] > self.hparams.predict_frame_threshold
        onset_predictions = y_probs[1][0] > self.hparams.predict_onset_threshold
        offset_predictions = y_probs[2][0] > self.hparams.predict_offset_threshold
        active_onsets = y_probs[1][0] > self.hparams.active_onset_threshold

        sequence = sequence_prediction_util.predict_sequence(
            frame_predictions=frame_predictions,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=None,
            min_pitch=constants.MIN_MIDI_PITCH,
            hparams=self.hparams,
            instrument=1,
            active_onsets=active_onsets)
        sequence.notes.extend(
            sequence_prediction_util.predict_sequence(
                frame_predictions=y_true[0][0],
                onset_predictions=y_true[1][0],
                offset_predictions=y_true[2][0],
                velocity_values=None,
                min_pitch=constants.MIN_MIDI_PITCH,
                hparams=self.hparams,
                instrument=0).notes)
        midi_filename = f'{self.save_dir}/{epoch}_predicted.midi'
        midi_io.sequence_proto_to_midi_file(sequence, midi_filename)
Example #3
  def testIsDrumDetection(self):
    """Verify that is_drum instruments are properly tracked.

    self.midi_is_drum_filename is a MIDI file containing two tracks
    set to channel 9 (is_drum == True). Each contains one NoteOn. This
    test is designed to catch a bug where the second track would lose
    is_drum, remapping the drum track to an instrument track.
    """
    sequence_proto = midi_io.midi_file_to_sequence_proto(
        self.midi_is_drum_filename)
    with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as temp_file:
      midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
      midi_data1 = mido.MidiFile(filename=self.midi_is_drum_filename)
      # Use the file object (not the filename) when reading the tempfile
      # back, to avoid a permission error on Windows.
      midi_data2 = mido.MidiFile(file=temp_file)

    # Count number of channel 9 Note Ons.
    channel_counts = [0, 0]
    for index, midi_data in enumerate([midi_data1, midi_data2]):
      for event in midi_data:
        if (event.type == 'note_on' and
            event.velocity > 0 and event.channel == 9):
          channel_counts[index] += 1
    self.assertEqual(channel_counts, [2, 2])
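As a side note, the is_drum flag this test guards lives on each individual note of the resulting NoteSequence; a quick hedged check (using the same midi_io.midi_file_to_sequence_proto call and a hypothetical drums.mid path) would be:

ns = midi_io.midi_file_to_sequence_proto('drums.mid')  # hypothetical file
drum_notes = [note for note in ns.notes if note.is_drum]
print('channel-9 (drum) notes found:', len(drum_notes))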
Example #4
def generate_midi(midi_data, total_seconds=10):
    primer_sequence = midi_io.midi_to_sequence_proto(midi_data)
    generate_request = generator_pb2.GenerateSequenceRequest()
    if len(primer_sequence.notes) > 4:
        estimated_tempo = midi_data.estimate_tempo()
        if estimated_tempo > 240:
            qpm = estimated_tempo / 2
        else:
            qpm = estimated_tempo
    else:
        qpm = 120
    primer_sequence.tempos[0].qpm = qpm
    generate_request.input_sequence.CopyFrom(primer_sequence)
    generate_section = (generate_request.generator_options.generate_sections.add())
    # Set the start time to begin on the next step after the last note ends.
    notes_by_end_time = sorted(primer_sequence.notes, key=lambda n: n.end_time)
    last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0
    generate_section.start_time_seconds = last_end_time + _steps_to_seconds(
            1, qpm)
    generate_section.end_time_seconds = total_seconds
    # generate_response = generator_map[generator_name].generate(generate_request)
    generate_response = basic_generator.generate(generate_request)
    output = tempfile.NamedTemporaryFile()
    midi_io.sequence_proto_to_midi_file(
          generate_response.generated_sequence, output.name)
    output.seek(0)
    return output
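The _steps_to_seconds helper called above is not shown in this example. A plausible sketch, assuming one step per quarter note (so each step lasts 60 / qpm seconds), is:

def _steps_to_seconds(steps, qpm):
    # With one step per quarter note, each step lasts 60 / qpm seconds.
    return steps * 60.0 / qpm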
Example #5
    def CheckReadWriteMidi(self, filename):
        """Test writing to a MIDI file and comparing it to the original Sequence."""

        # TODO(deck): The input MIDI file is opened in pretty-midi and
        # re-written to a temp file, sanitizing the MIDI data (reordering
        # note ons, etc). Issue 85 in the pretty-midi GitHub
        # (http://github.com/craffel/pretty-midi/issues/85) requests that
        # this sanitization be available outside of the context of a file
        # write. If that is implemented, this rewrite code should be
        # modified or deleted.

        # When writing to the temp file, use the file object itself instead of
        # file.name to avoid the permission error on Windows.
        with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as rewrite_file:
            original_midi = pretty_midi.PrettyMIDI(filename)
            original_midi.write(rewrite_file)  # Use file object
            # Rewind to the start of the file so rewrite_file can be re-read.
            rewrite_file.seek(0)
            source_midi = pretty_midi.PrettyMIDI(
                rewrite_file)  # Use file object
            sequence_proto = midi_io.midi_to_sequence_proto(source_midi)

        # Translate the NoteSequence to MIDI and write to a file.
        with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as temp_file:
            midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
            # Read it back in and compare to source.
            created_midi = pretty_midi.PrettyMIDI(temp_file)  # Use file object

        self.CheckPrettyMidiAndSequence(created_midi, sequence_proto)
Example #6
def main(argv):
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.acoustic_checkpoint_filename:
        acoustic_checkpoint = os.path.join(
            os.path.expanduser(FLAGS.acoustic_run_dir), 'train',
            FLAGS.acoustic_checkpoint_filename)
    else:
        acoustic_checkpoint = tf.train.latest_checkpoint(
            os.path.join(os.path.expanduser(FLAGS.acoustic_run_dir), 'train'))

    hparams = tf_utils.merge_hparams(constants.DEFAULT_HPARAMS,
                                     model.get_default_hparams())
    hparams.parse(FLAGS.hparams)

    transcription_session = initialize_session(acoustic_checkpoint, hparams)

    for filename in argv[1:]:
        tf.logging.info('Starting transcription for %s...', filename)

        sequence_prediction = transcribe_audio(transcription_session, filename,
                                               FLAGS.frame_threshold,
                                               FLAGS.onset_threshold)

        midi_filename = filename + '.midi'
        midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

        tf.logging.info('Transcription written to %s.', midi_filename)
Example #7
    def testIsDrumDetection(self):
        """Verify that is_drum instruments are properly tracked.

        self.midi_is_drum_filename is a MIDI file containing two tracks
        set to channel 9 (is_drum == True). Each contains one NoteOn. This
        test is designed to catch a bug where the second track would lose
        is_drum, remapping the drum track to an instrument track.
        """
        sequence_proto = midi_io.midi_file_to_sequence_proto(
            self.midi_is_drum_filename)
        with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as temp_file:
            midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
            midi_data1 = mido.MidiFile(filename=self.midi_is_drum_filename)
            # Use the file object (not the filename) when reading the tempfile
            # back, to avoid a permission error on Windows.
            midi_data2 = mido.MidiFile(file=temp_file)

        # Count number of channel 9 Note Ons.
        channel_counts = [0, 0]
        for index, midi_data in enumerate([midi_data1, midi_data2]):
            for event in midi_data:
                if (event.type == 'note_on' and event.velocity > 0
                        and event.channel == 9):
                    channel_counts[index] += 1
        self.assertEqual(channel_counts, [2, 2])
Example #8
def main(argv):
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.acoustic_checkpoint_filename:
    acoustic_checkpoint = os.path.join(
        os.path.expanduser(FLAGS.acoustic_run_dir), 'train',
        FLAGS.acoustic_checkpoint_filename)
  else:
    acoustic_checkpoint = tf.train.latest_checkpoint(
        os.path.join(os.path.expanduser(FLAGS.acoustic_run_dir), 'train'))

  hparams = tf_utils.merge_hparams(
      constants.DEFAULT_HPARAMS, model.get_default_hparams())
  hparams.parse(FLAGS.hparams)

  transcription_session = initialize_session(acoustic_checkpoint, hparams)

  for filename in argv[1:]:
    tf.logging.info('Starting transcription for %s...', filename)

    sequence_prediction = transcribe_audio(
        transcription_session, filename, FLAGS.frame_threshold,
        FLAGS.onset_threshold)

    midi_filename = filename + '.midi'
    midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

    tf.logging.info('Transcription written to %s.', midi_filename)
Example #9
def main(unused_argv):
  logging.set_verbosity(FLAGS.log)
  if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)
  for input_file in sorted(os.listdir(FLAGS.input_dir)):
    if not input_file.endswith('.wav'):
      continue
    wav_filename = input_file
    midi_filename = input_file.replace('.wav', '.mid')
    logging.info('Aligning %s to %s', midi_filename, wav_filename)

    samples = audio_io.load_audio(
        os.path.join(FLAGS.input_dir, wav_filename), align_fine_lib.SAMPLE_RATE)
    ns = midi_io.midi_file_to_sequence_proto(
        os.path.join(FLAGS.input_dir, midi_filename))

    aligned_ns, unused_stats = align_fine_lib.align_cpp(
        samples,
        align_fine_lib.SAMPLE_RATE,
        ns,
        align_fine_lib.CQT_HOP_LENGTH_FINE,
        sf2_path=FLAGS.sf2_path,
        penalty_mul=FLAGS.penalty_mul)

    midi_io.sequence_proto_to_midi_file(
        aligned_ns, os.path.join(FLAGS.output_dir, midi_filename))

  logging.info('Done')
Example #10
    def save_stack_midi(self, y_probs, y_true, epoch):
        frame_predictions = y_true[0][0]
        onset_predictions = y_true[1][0]
        offset_predictions = y_true[2][0]
        sequence = sequence_prediction_util.predict_sequence(
            frame_predictions=frame_predictions,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=None,
            min_pitch=constants.MIN_MIDI_PITCH,
            hparams=self.hparams,
            instrument=0)

        for i in range(3):
            # Output midi values for each stack (frames, onsets, offsets).
            sequence.notes.extend(
                sequence_prediction_util.predict_sequence(
                    frame_predictions=y_probs[i][0] > 0.5,
                    onset_predictions=None,
                    offset_predictions=None,
                    velocity_values=None,
                    min_pitch=constants.MIN_MIDI_PITCH,
                    hparams=self.hparams,
                    instrument=i + 1).notes)

        midi_filename = f'{self.save_dir}/{epoch}_stacks.midi'
        midi_io.sequence_proto_to_midi_file(sequence, midi_filename)
Example #11
  def CheckReadWriteMidi(self, filename):
    """Test writing to a MIDI file and comparing it to the original Sequence."""

    # TODO(deck): The input MIDI file is opened in pretty-midi and
    # re-written to a temp file, sanitizing the MIDI data (reordering
    # note ons, etc). Issue 85 in the pretty-midi GitHub
    # (http://github.com/craffel/pretty-midi/issues/85) requests that
    # this sanitization be available outside of the context of a file
    # write. If that is implemented, this rewrite code should be
    # modified or deleted.

    # When writing to the temp file, use the file object itself instead of
    # file.name to avoid the permission error on Windows.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as rewrite_file:
      original_midi = pretty_midi.PrettyMIDI(filename)
      original_midi.write(rewrite_file)  # Use file object
      # Rewind to the start of the file so rewrite_file can be re-read.
      rewrite_file.seek(0)
      source_midi = pretty_midi.PrettyMIDI(rewrite_file)  # Use file object
      sequence_proto = midi_io.midi_to_sequence_proto(source_midi)

    # Translate the NoteSequence to MIDI and write to a file.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as temp_file:
      midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
      # Read it back in and compare to source.
      created_midi = pretty_midi.PrettyMIDI(temp_file)  # Use file object

    self.CheckPrettyMidiAndSequence(created_midi, sequence_proto)
Example #12
def infer(filename):
    # Read the WAV file as raw bytes.
    wav = open(filename, 'rb')
    wav_data = wav.read()
    wav.close()

    tf.logging.info('User .WAV file %s length %s bytes', filename,
                    len(wav_data))

    ## Preprocessing
    # Split into chunks and convert to Protocol Buffers.
    to_process = []
    examples = list(
        audio_label_data_utils.process_record(wav_data=wav_data,
                                              ns=music_pb2.NoteSequence(),
                                              example_id=filename,
                                              min_length=0,
                                              max_length=-1,
                                              allow_empty_notesequence=True))

    # Serialize the chunked examples.
    to_process.append(examples[0].SerializeToString())

    #############################################################

    # Feed the serialized examples into the iterator.
    sess.run(iterator.initializer, {example: to_process})

    # Inference
    predictions = list(estimator.predict(input_fn,
                                         yield_single_examples=False))
    # Assert that exactly one prediction was produced.
    assert len(predictions) == 1

    # Unpack the prediction results.
    frame_predictions = predictions[0]['frame_predictions'][0]
    onset_predictions = predictions[0]['onset_predictions'][0]  # note onsets
    velocity_values = predictions[0]['velocity_values'][0]  # note velocities (dynamics)

    # Encode the predictions as a MIDI note sequence.
    sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
        frame_predictions,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_duration_ms=0,
        min_midi_pitch=constants.MIN_MIDI_PITCH,
        onset_predictions=onset_predictions,
        velocity_values=velocity_values)

    basename = os.path.split(os.path.splitext(filename)[0])[1] + '.mid'
    output_filename = os.path.join('', basename)

    midi_filename = output_filename
    midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

    print('Program Ended, Your MIDI File is in', output_filename)

    sess.close()
Example #13
    def process(files):
        for fn in files:
            print('**\n\n', fn, '\n\n**')
            with open(fn, 'rb', buffering=0) as f:
                wav_data = f.read()
            example_list = list(
                audio_label_data_utils.process_record(
                    wav_data=wav_data,
                    ns=music_pb2.NoteSequence(),
                    example_id=fn,
                    min_length=0,
                    max_length=-1,
                    allow_empty_notesequence=True))
            assert len(example_list) == 1
            to_process.append(example_list[0].SerializeToString())
            print('Processing complete for', fn)

            sess = tf.Session()

            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            sess.run(iterator.initializer, {examples: to_process})

            def transcription_data(params):
                del params
                return tf.data.Dataset.from_tensors(sess.run(next_record))


            input_fn = infer_util.labels_to_features_wrapper(transcription_data)

            # Run inference.
            prediction_list = list(
                estimator.predict(
                    input_fn,
                    yield_single_examples=False))
            assert len(prediction_list) == 1

            # Ignore warnings caused by pyfluidsynth
            import warnings
            warnings.filterwarnings("ignore", category=DeprecationWarning) 

            sequence_prediction = music_pb2.NoteSequence.FromString(
                prediction_list[0]['sequence_predictions'][0])

            pathname = fn.split('/').pop()
            print('**\n\n', pathname, '\n\n**')
            midi_filename = '{outputs}/{file}.mid'.format(outputs=output, file=pathname)
            midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)
Example #14
def inference(filename):
    # Read the audio (.wav) file.
    wav_file = open(filename, mode='rb')
    wav_data = wav_file.read()
    wav_file.close()
    
    print('User uploaded file "{name}" with length {length} bytes'.format(name=filename, length=len(wav_data)))

    # Split into chunks and build protobuf examples.
    to_process = []
    example_list = list(
        audio_label_data_utils.process_record(
            wav_data=wav_data, ns=music_pb2.NoteSequence(), example_id=filename,
            min_length=0, max_length=-1, allow_empty_notesequence=True))
    
    # Serialize
    to_process.append(example_list[0].SerializeToString())

    # Run the session to feed the examples into the iterator.
    sess.run(iterator.initializer, {examples: to_process})

    # Run inference.
    prediction_list = list(estimator.predict(input_fn, yield_single_examples=False))
    assert len(prediction_list) == 1

    # Fetch the prediction results.
    frame_predictions = prediction_list[0]['frame_predictions'][0]
    onset_predictions = prediction_list[0]['onset_predictions'][0]
    velocity_values = prediction_list[0]['velocity_values'][0]

    # Build a MIDI note sequence from the prediction results.
    sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
        frame_predictions,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_duration_ms=0,
        min_midi_pitch=constants.MIN_MIDI_PITCH,
        onset_predictions=onset_predictions,
        velocity_values=velocity_values)

    basename = os.path.split(os.path.splitext(filename)[0])[1] + '.mid'
    output_filename = os.path.join(env.MIDI_DIRECTORY, basename)

    # Write the MIDI sequence to a file.
    midi_filename = output_filename
    midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

    return basename
Example #15
def main(argv):
  tf.logging.set_verbosity(FLAGS.log)

  config = configs.CONFIG_MAP[FLAGS.config]
  hparams = config.hparams
  # For this script, default to not using cudnn.
  hparams.use_cudnn = False
  hparams.parse(FLAGS.hparams)
  hparams.batch_size = 1
  hparams.truncated_length_secs = 0

  with tf.Graph().as_default():
    examples = tf.placeholder(tf.string, [None])

    dataset = data.provide_batch(
        examples=examples,
        preprocess_examples=True,
        hparams=hparams,
        is_training=False)

    estimator = train_util.create_estimator(config.model_fn,
                                            os.path.expanduser(FLAGS.model_dir),
                                            hparams)

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()

    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      for filename in argv[1:]:
        tf.logging.info('Starting transcription for %s...', filename)

        # The reason we bounce between two Dataset objects is so we can use
        # the data processing functionality in data.py without having to
        # construct all the Example protos in memory ahead of time or create
        # a temporary tfrecord file.
        tf.logging.info('Processing file...')
        sess.run(iterator.initializer, {examples: [create_example(filename)]})

        def input_fn(params):
          del params
          return tf.data.Dataset.from_tensors(sess.run(next_record))

        tf.logging.info('Running inference...')
        checkpoint_path = None
        if FLAGS.checkpoint_path:
          checkpoint_path = os.path.expanduser(FLAGS.checkpoint_path)
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1

        sequence_prediction = transcribe_audio(prediction_list[0], hparams)

        midi_filename = filename + '.midi'
        midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

        tf.logging.info('Transcription written to %s.', midi_filename)
Example #16
def run(argv, config_map, data_fn):
    """Create transcriptions."""
    tf.logging.set_verbosity(FLAGS.log)

    config = config_map[FLAGS.config]
    hparams = config.hparams
    # For this script, default to not using cudnn.
    hparams.use_cudnn = False
    hparams.parse(FLAGS.hparams)
    hparams.batch_size = 1
    hparams.truncated_length_secs = 0

    with tf.Graph().as_default():
        examples = tf.placeholder(tf.string, [None])

        dataset = data_fn(examples=examples,
                          preprocess_examples=True,
                          params=hparams,
                          is_training=False,
                          shuffle_examples=False,
                          skip_n_initial_records=0)

        estimator = train_util.create_estimator(
            config.model_fn, os.path.expanduser(FLAGS.model_dir), hparams)

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()

        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            for filename in argv[1:]:
                tf.logging.info('Starting transcription for %s...', filename)

                # The reason we bounce between two Dataset objects is so we can use
                # the data processing functionality in data.py without having to
                # construct all the Example protos in memory ahead of time or create
                # a temporary tfrecord file.
                tf.logging.info('Processing file...')
                sess.run(
                    iterator.initializer, {
                        examples: [
                            create_example(filename,
                                           FLAGS.load_audio_with_librosa)
                        ]
                    })

                def transcription_data(params):
                    del params
                    return tf.data.Dataset.from_tensors(sess.run(next_record))

                input_fn = infer_util.labels_to_features_wrapper(
                    transcription_data)

                tf.logging.info('Running inference...')
                checkpoint_path = None
                if FLAGS.checkpoint_path:
                    checkpoint_path = os.path.expanduser(FLAGS.checkpoint_path)
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                assert len(prediction_list) == 1

                sequence_prediction = music_pb2.NoteSequence.FromString(
                    prediction_list[0]['sequence_predictions'][0])

                midi_filename = filename + FLAGS.transcribed_file_suffix + '.midi'
                midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                                    midi_filename)

                tf.logging.info('Transcription written to %s.', midi_filename)
Example #17
def model_inference(acoustic_checkpoint, hparams, examples_path, run_dir):
  """Runs inference for the given examples."""
  tf.logging.info('acoustic_checkpoint=%s', acoustic_checkpoint)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('run_dir=%s', run_dir)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    # Build the acoustic model within an 'acoustic' scope to isolate its
    # variables from the other models.
    with tf.variable_scope('acoustic'):
      truncated_length = 0
      if FLAGS.max_seconds_per_sequence:
        truncated_length = int(
            math.ceil((FLAGS.max_seconds_per_sequence *
                       data.hparams_frames_per_second(hparams))))
      acoustic_data_provider, _ = data.provide_batch(
          batch_size=1,
          examples=examples_path,
          hparams=hparams,
          is_training=False,
          truncated_length=truncated_length,
          include_note_sequences=True)

      _, _, data_labels, _, _ = model.get_model(
          acoustic_data_provider, hparams, is_training=False)

    # The checkpoints won't have the new scopes.
    acoustic_variables = {
        re.sub(r'^acoustic/', '', var.op.name): var
        for var in slim.get_variables(scope='acoustic/')
    }
    acoustic_restore = tf.train.Saver(acoustic_variables)

    onset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/onsets/onset_probs_flat:0')
    frame_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/frame_probs_flat:0')
    offset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/offsets/offset_probs_flat:0')
    velocity_values_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/velocity/velocity_values_flat:0')

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()
    global_step = tf.contrib.framework.get_or_create_global_step()
    global_step_increment = global_step.assign_add(1)

    # Use a custom init function to restore the acoustic and language models
    # from their separate checkpoints.
    def init_fn(unused_self, sess):
      acoustic_restore.restore(sess, acoustic_checkpoint)

    scaffold = tf.train.Scaffold(init_fn=init_fn)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, master=FLAGS.master)
    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
      tf.logging.info('running session')
      summary_writer = tf.summary.FileWriter(
          logdir=run_dir, graph=sess.graph)

      tf.logging.info('Inferring for %d batches',
                      acoustic_data_provider.num_batches)
      infer_times = []
      num_frames = []
      for unused_i in range(acoustic_data_provider.num_batches):
        start_time = time.time()
        (labels, filenames, note_sequences, frame_probs, onset_probs,
         offset_probs, velocity_values) = sess.run([
             data_labels,
             acoustic_data_provider.filenames,
             acoustic_data_provider.note_sequences,
             frame_probs_flat,
             onset_probs_flat,
             offset_probs_flat,
             velocity_values_flat,
         ])
        # We expect these all to be length 1 because batch size is 1.
        assert len(filenames) == len(note_sequences) == 1
        # These should be the same length and have been flattened.
        assert len(labels) == len(frame_probs) == len(onset_probs)

        frame_predictions = frame_probs > FLAGS.frame_threshold
        if FLAGS.require_onset:
          onset_predictions = onset_probs > FLAGS.onset_threshold
        else:
          onset_predictions = None

        if FLAGS.use_offset:
          offset_predictions = offset_probs > FLAGS.offset_threshold
        else:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_probs.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_probs.shape[0], frame_probs.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filenames[0])

        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        sequence_label = infer_util.score_sequence(
            sess,
            global_step_increment,
            summary_op,
            summary_writer,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequences_lib.adjust_notesequence_times(
                music_pb2.NoteSequence.FromString(note_sequences[0]),
                shift_notesequence)[0],
            sequence_id=filenames[0])

        # Make filenames UNIX-friendly.
        filename = filenames[0].decode('utf-8').replace('/', '_').replace(
            ':', '.')
        output_file = os.path.join(run_dir, filename + '.mid')
        tf.logging.info('Writing inferred midi file to %s', output_file)
        midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

        label_output_file = os.path.join(run_dir, filename + '_label.mid')
        tf.logging.info('Writing label midi file to %s', label_output_file)
        midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

        # Also write a pianoroll showing acoustic model output vs labels.
        pianoroll_output_file = os.path.join(run_dir,
                                             filename + '_pianoroll.png')
        tf.logging.info('Writing acoustic logit/label file to %s',
                        pianoroll_output_file)
        with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
          scipy.misc.imsave(
              f,
              infer_util.posterior_pianoroll_image(
                  frame_probs,
                  sequence_prediction,
                  labels,
                  overlap=True,
                  frames_per_second=data.hparams_frames_per_second(hparams)))

        summary_writer.flush()
Example #18
    def call(self, input_list, **kwargs):
        """
        Convert note croppings and their corresponding timbre
        predictions to a pianoroll that
        we can multiply by the melodic predictions.
        :param input_list: note_croppings, timbre_probs, pianorolls
        :return: a pianoroll with shape:
        (batches, pianoroll_length, 88, timbre_num_classes + 1)
        """
        batched_note_croppings, batched_timbre_probs, batched_pianorolls = input_list

        pianoroll_list = []
        for batch_idx in range(K.int_shape(batched_note_croppings)[0]):
            note_croppings = batched_note_croppings[batch_idx]
            timbre_probs = batched_timbre_probs[batch_idx]

            pianorolls = K.zeros(
                shape=(K.int_shape(batched_pianorolls[batch_idx])[0],
                       constants.MIDI_PITCHES,
                       self.hparams.timbre_num_classes))
            ones = np.ones(
                shape=(K.int_shape(batched_pianorolls[batch_idx])[0],
                       constants.MIDI_PITCHES,
                       self.hparams.timbre_num_classes))

            for i, note_cropping in enumerate(note_croppings):
                cropping = NoteCropping(*note_cropping)
                pitch = cropping.pitch - constants.MIN_MIDI_PITCH
                if cropping.end_idx < 0:
                    # Don't fill padded notes.
                    continue
                start_idx = K.cast(cropping.start_idx / self.hparams.spec_hop_length, 'int64')
                end_idx = K.cast(cropping.end_idx / self.hparams.spec_hop_length, 'int64')
                pitch_mask = K.cast_to_floatx(tf.one_hot(pitch, constants.MIDI_PITCHES))
                end_time_mask = K.cast(tf.sequence_mask(
                    end_idx,
                    maxlen=K.int_shape(batched_pianorolls[batch_idx])[0]
                ), tf.float32)
                start_time_mask = K.cast(tf.math.logical_not(tf.sequence_mask(
                    start_idx,
                    maxlen=K.int_shape(batched_pianorolls[batch_idx])[0]
                )), tf.float32)
                time_mask = start_time_mask * end_time_mask
                # Constant time for the pitch mask.
                pitch_mask = K.expand_dims(K.expand_dims(pitch_mask, 0))
                # Constant pitch for the time mask.
                time_mask = K.expand_dims(K.expand_dims(time_mask, 1))
                mask = ones * pitch_mask
                mask = mask * time_mask
                cropped_probs = mask * (timbre_probs[i])
                if K.learning_phase() == 1:
                    # For training, this is necessary for the gradient.
                    pianorolls = pianorolls + cropped_probs
                else:
                    # For testing, this is faster.
                    pianorolls.assign_add(cropped_probs)

            frame_predictions = pianorolls > self.hparams.multiple_instruments_threshold
            sequence = sequence_prediction_util.predict_multi_sequence(
                frame_predictions=frame_predictions,
                min_pitch=constants.MIN_MIDI_PITCH,
                hparams=self.hparams)
            midi_filename = (
                f'./out/{batch_idx}-of-{K.int_shape(batched_note_croppings)[0]}.midi'
            )
            midi_io.sequence_proto_to_midi_file(sequence, midi_filename)
            # Make time the first dimension.
            pianoroll_list.append(pianorolls)

        return tf.convert_to_tensor(pianoroll_list)
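The masking logic in the loop above combines a one-hot pitch mask with a frame-count time mask built from tf.sequence_mask. A small standalone illustration of that outer-product pattern (made-up sizes of 4 frames and 5 pitches, runs eagerly under TF 2) is:

import tensorflow as tf

# Hypothetical sizes: 4 frames, 5 pitches; the note occupies frames 0-2 at pitch 2.
pitch_mask = tf.one_hot(2, 5)                                    # shape (5,)
time_mask = tf.cast(tf.sequence_mask(3, maxlen=4), tf.float32)   # shape (4,)
mask = time_mask[:, None] * pitch_mask[None, :]                  # shape (4, 5)
print(mask.numpy())  # 1.0 only where the note is active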
Example #19
File: demo3.py  Project: LegendFC/InsMaster
                value=[fn.encode('utf-8')])),
            'sequence':
            tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[music_pb2.NoteSequence().SerializeToString()])),
            'audio':
            tf.train.Feature(bytes_list=tf.train.BytesList(value=[wav_data])),
        }))
    to_process.append(example.SerializeToString())
    print('Processing complete for', fn)

session.run(iterator.initializer, {examples: to_process})

filenames, frame_logits, onset_logits = session.run(
    [batch.filenames, frame_probs_flat, onset_probs_flat])

print('Inference complete for', filenames[0])

frame_predictions = frame_logits > .5

onset_predictions = onset_logits > .5

sequence_prediction = infer_util.pianoroll_to_note_sequence(
    frame_predictions,
    frames_per_second=data.hparams_frames_per_second(hparams),
    min_duration_ms=0,
    onset_predictions=onset_predictions)

midi_filename = (filenames[0] + '.mid').replace(' ', '_')
midi_path = './output/' + midi_filename
midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_path)
Example #20
def main(argv):
    tf.logging.set_verbosity(FLAGS.log)

    hparams = tf_utils.merge_hparams(constants.DEFAULT_HPARAMS,
                                     model.get_default_hparams())
    # For this script, default to not using cudnn.
    hparams.use_cudnn = False
    hparams.parse(FLAGS.hparams)
    hparams.batch_size = 1

    with tf.Graph().as_default():
        examples = tf.placeholder(tf.string, [None])

        dataset = data.provide_batch(batch_size=1,
                                     examples=examples,
                                     hparams=hparams,
                                     is_training=False,
                                     truncated_length=0)

        estimator = train_util.create_estimator(
            os.path.expanduser(FLAGS.model_dir), hparams)

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()

        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            for filename in argv[1:]:
                tf.logging.info('Starting transcription for %s...', filename)

                # The reason we bounce between two Dataset objects is so we can use
                # the data processing functionality in data.py without having to
                # construct all the Example protos in memory ahead of time or create
                # a temporary tfrecord file.
                tf.logging.info('Processing file...')
                sess.run(iterator.initializer,
                         {examples: [create_example(filename)]})

                def input_fn():
                    return tf.data.Dataset.from_tensors(sess.run(next_record))

                tf.logging.info('Running inference...')
                checkpoint_path = None
                if FLAGS.checkpoint_path:
                    checkpoint_path = os.path.expanduser(FLAGS.checkpoint_path)
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                assert len(prediction_list) == 1

                sequence_prediction = transcribe_audio(prediction_list[0],
                                                       hparams,
                                                       FLAGS.frame_threshold,
                                                       FLAGS.onset_threshold)

                midi_filename = filename + '.midi'
                midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                                    midi_filename)

                tf.logging.info('Transcription written to %s.', midi_filename)
Example #21
def model_inference(model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    write_summary_every_step=True):
    """Runs inference for the given examples."""
    tf.logging.info('model_dir=%s', model_dir)
    tf.logging.info('checkpoint_path=%s', checkpoint_path)
    tf.logging.info('examples_path=%s', examples_path)
    tf.logging.info('output_dir=%s', output_dir)

    estimator = train_util.create_estimator(model_dir, hparams)

    with tf.Graph().as_default():
        num_dims = constants.MIDI_PITCHES

        if FLAGS.max_seconds_per_sequence:
            truncated_length = int(
                math.ceil((FLAGS.max_seconds_per_sequence *
                           data.hparams_frames_per_second(hparams))))
        else:
            truncated_length = 0

        dataset = data.provide_batch(batch_size=1,
                                     examples=examples_path,
                                     hparams=hparams,
                                     is_training=False,
                                     truncated_length=truncated_length)

        # Define some metrics.
        (metrics_to_updates, metric_note_precision, metric_note_recall,
         metric_note_f1, metric_note_precision_with_offsets,
         metric_note_recall_with_offsets, metric_note_f1_with_offsets,
         metric_note_precision_with_offsets_velocity,
         metric_note_recall_with_offsets_velocity,
         metric_note_f1_with_offsets_velocity, metric_frame_labels,
         metric_frame_predictions) = infer_util.define_metrics(num_dims)

        summary_op = tf.summary.merge_all()

        if write_summary_every_step:
            global_step = tf.train.get_or_create_global_step()
            global_step_increment = global_step.assign_add(1)
        else:
            global_step = tf.constant(
                estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
            global_step_increment = global_step

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()
        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            infer_times = []
            num_frames = []

            sess.run(iterator.initializer)
            while True:
                try:
                    record = sess.run(next_record)
                except tf.errors.OutOfRangeError:
                    break

                def input_fn():
                    return tf.data.Dataset.from_tensors(record)

                start_time = time.time()

                # TODO(fjord): This is a hack that allows us to keep using our existing
                # infer/scoring code with a tf.Estimator model. Ideally, we should
                # move things around so that we can use estimator.evaluate, which will
                # also be more efficient because it won't have to restore the checkpoint
                # for every example.
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                assert len(prediction_list) == 1

                input_features = record[0]
                input_labels = record[1]

                filename = input_features.sequence_id[0]
                note_sequence = music_pb2.NoteSequence.FromString(
                    input_labels.note_sequence[0])
                labels = input_labels.labels[0]
                frame_probs = prediction_list[0]['frame_probs_flat']
                onset_probs = prediction_list[0]['onset_probs_flat']
                velocity_values = prediction_list[0]['velocity_values_flat']
                offset_probs = prediction_list[0]['offset_probs_flat']

                frame_predictions = frame_probs > FLAGS.frame_threshold
                if FLAGS.require_onset:
                    onset_predictions = onset_probs > FLAGS.onset_threshold
                else:
                    onset_predictions = None

                if FLAGS.use_offset:
                    offset_predictions = offset_probs > FLAGS.offset_threshold
                else:
                    offset_predictions = None

                sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
                    frame_predictions,
                    frames_per_second=data.hparams_frames_per_second(hparams),
                    min_duration_ms=0,
                    min_midi_pitch=constants.MIN_MIDI_PITCH,
                    onset_predictions=onset_predictions,
                    offset_predictions=offset_predictions,
                    velocity_values=velocity_values)

                end_time = time.time()
                infer_time = end_time - start_time
                infer_times.append(infer_time)
                num_frames.append(frame_probs.shape[0])
                tf.logging.info(
                    'Infer time %f, frames %d, frames/sec %f, running average %f',
                    infer_time, frame_probs.shape[0],
                    frame_probs.shape[0] / infer_time,
                    np.sum(num_frames) / np.sum(infer_times))

                tf.logging.info('Scoring sequence %s', filename)

                def shift_notesequence(ns_time):
                    return ns_time + hparams.backward_shift_amount_ms / 1000.

                sequence_label = sequences_lib.adjust_notesequence_times(
                    note_sequence, shift_notesequence)[0]
                infer_util.score_sequence(
                    sess,
                    global_step_increment,
                    metrics_to_updates,
                    metric_note_precision,
                    metric_note_recall,
                    metric_note_f1,
                    metric_note_precision_with_offsets,
                    metric_note_recall_with_offsets,
                    metric_note_f1_with_offsets,
                    metric_note_precision_with_offsets_velocity,
                    metric_note_recall_with_offsets_velocity,
                    metric_note_f1_with_offsets_velocity,
                    metric_frame_labels,
                    metric_frame_predictions,
                    frame_labels=labels,
                    sequence_prediction=sequence_prediction,
                    frames_per_second=data.hparams_frames_per_second(hparams),
                    sequence_label=sequence_label,
                    sequence_id=filename)

                if write_summary_every_step:
                    # Make filenames UNIX-friendly.
                    filename_safe = filename.decode('utf-8').replace(
                        '/', '_').replace(':', '.')
                    output_file = os.path.join(output_dir,
                                               filename_safe + '.mid')
                    tf.logging.info('Writing inferred midi file to %s',
                                    output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_prediction, output_file)

                    label_output_file = os.path.join(
                        output_dir, filename_safe + '_label.mid')
                    tf.logging.info('Writing label midi file to %s',
                                    label_output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_label, label_output_file)

                    # Also write a pianoroll showing acoustic model output vs labels.
                    pianoroll_output_file = os.path.join(
                        output_dir, filename_safe + '_pianoroll.png')
                    tf.logging.info('Writing acoustic logit/label file to %s',
                                    pianoroll_output_file)
                    with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
                        scipy.misc.imsave(
                            f,
                            infer_util.posterior_pianoroll_image(
                                frame_probs,
                                sequence_prediction,
                                labels,
                                overlap=True,
                                frames_per_second=data.hparams_frames_per_second(
                                    hparams)))

                    summary = sess.run(summary_op)
                    summary_writer.add_summary(summary, sess.run(global_step))
                    summary_writer.flush()

            if not write_summary_every_step:
                # Only write the summary variables for the final step.
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, sess.run(global_step))
                summary_writer.flush()
Example #22
def transcribe(data_fn,
               model_dir,
               model_type,
               path,
               file_suffix,
               hparams,
               load_full=False,
               qpm=None):
    if data_fn:
        transcription_data = data_fn(preprocess_examples=True,
                                     is_training=False,
                                     shuffle_examples=True,
                                     skip_n_initial_records=0,
                                     hparams=hparams)
    else:
        transcription_data = None

    model_wrapper = ModelWrapper(model_dir,
                                 model_type,
                                 dataset=transcription_data,
                                 batch_size=1,
                                 id_=hparams.model_id,
                                 hparams=hparams)
    model_wrapper.load_weights(load_full)

    if data_fn:
        while True:
            # This will exit when the dataset runs out.
            # Generally, just predict on filenames rather than
            # TFRecords so you don't use this code.
            if model_type is ModelType.MELODIC:
                x, _ = model_wrapper.generator.get()
                sequence_prediction = model_wrapper.predict_from_spec(x[0])
                midi_filename = path + file_suffix + '.midi'
                midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                                    midi_filename)
            elif model_type is ModelType.TIMBRE:
                x, y = model_wrapper.generator.get()
                timbre_prediction = K.get_value(
                    model_wrapper.predict_from_spec(*x))[0]
                print(
                    f'True: {x[1][0][0]}{constants.FAMILY_IDX_STRINGS[np.argmax(y[0][0])]}. '
                    f'Predicted: {constants.FAMILY_IDX_STRINGS[timbre_prediction]}'
                )
    else:
        filenames = glob.glob(path)
        for filename in filenames:
            logging.info('Starting transcription for %s...', filename)

            samples, sr = librosa.load(filename, hparams.sample_rate)

            if model_type is ModelType.TIMBRE:
                spec = timbre_dataset_util.create_timbre_spectrogram(
                    samples, hparams)
                # Add "batch" and channel dims.
                spec = K.cast_to_floatx(tf.reshape(spec, (1, *spec.shape, 1)))
                timbre_prediction = K.get_value(
                    model_wrapper.predict_from_spec(spec))[0]
                print(
                    f'File: {filename}. '
                    f'Predicted: {constants.FAMILY_IDX_STRINGS[timbre_prediction]}'
                )
                continue
            elif model_type is ModelType.MELODIC:
                spec = samples_to_cqt(samples, hparams=hparams)
                if hparams.spec_log_amplitude:
                    spec = librosa.power_to_db(spec)

                # Add "batch" and channel dims.
                spec = tf.reshape(spec, (1, *spec.shape, 1))

                logging.info('Running inference...')
                sequence_prediction = model_wrapper.predict_from_spec(spec,
                                                                      qpm=qpm)
            else:
                melodic_spec = samples_to_cqt(samples, hparams=hparams)
                if hparams.spec_log_amplitude:
                    melodic_spec = librosa.power_to_db(melodic_spec)

                timbre_spec = timbre_dataset_util.create_timbre_spectrogram(
                    samples, hparams)

                # Add "batch" and channel dims.
                melodic_spec = tf.reshape(melodic_spec,
                                          (1, *melodic_spec.shape, 1))
                timbre_spec = tf.reshape(timbre_spec,
                                         (1, *timbre_spec.shape, 1))

                logging.info('Running inference...')

                if hparams.present_instruments:
                    present_instruments = K.expand_dims(
                        hparams.present_instruments, 0)
                else:
                    present_instruments = None

                sequence_prediction = (model_wrapper.predict_multi_sequence(
                    melodic_spec=melodic_spec,
                    timbre_spec=timbre_spec,
                    present_instruments=present_instruments,
                    qpm=qpm))
            midi_filename = filename + file_suffix + '.midi'
            midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                                midi_filename)

            logging.info('Transcription written to %s.', midi_filename)
Example #23
File: infer.py  Project: johnnyVR/magenta
def model_inference(model_fn, model_dir, checkpoint_path, data_fn, hparams,
                    examples_path, output_dir, summary_writer, master,
                    preprocess_examples, shuffle_examples):
    """Runs inference for the given examples."""
    tf.logging.info('model_dir=%s', model_dir)
    tf.logging.info('checkpoint_path=%s', checkpoint_path)
    tf.logging.info('examples_path=%s', examples_path)
    tf.logging.info('output_dir=%s', output_dir)

    estimator = train_util.create_estimator(model_fn,
                                            model_dir,
                                            hparams,
                                            master=master)

    transcription_data = functools.partial(
        data_fn,
        examples=examples_path,
        preprocess_examples=preprocess_examples,
        is_training=False,
        shuffle_examples=shuffle_examples,
        skip_n_initial_records=0)

    input_fn = infer_util.labels_to_features_wrapper(transcription_data)

    start_time = time.time()
    infer_times = []
    num_frames = []

    file_num = 0

    all_metrics = collections.defaultdict(list)

    for predictions in estimator.predict(input_fn,
                                         checkpoint_path=checkpoint_path,
                                         yield_single_examples=False):

        # Remove batch dimension for convenience.
        for k in predictions.keys():
            if predictions[k].shape[0] != 1:
                raise ValueError(
                    'All predictions must have batch size 1, but shape of '
                    '{} was: {}'.format(k, predictions[k].shape[0]))
            predictions[k] = predictions[k][0]

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(predictions['frame_predictions'].shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, num_frames[-1], num_frames[-1] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', predictions['sequence_ids'])

        sequence_prediction = music_pb2.NoteSequence.FromString(
            predictions['sequence_predictions'])
        sequence_label = music_pb2.NoteSequence.FromString(
            predictions['sequence_labels'])

        # Make filenames UNIX-friendly.
        filename_chars = predictions['sequence_ids'].decode('utf-8')
        filename_chars = [c if c.isalnum() else '_' for c in filename_chars]
        filename_safe = ''.join(filename_chars).rstrip()
        filename_safe = '{:04d}_{}'.format(file_num, filename_safe[:200])
        file_num += 1
        output_file = os.path.join(output_dir, filename_safe + '.mid')
        tf.logging.info('Writing inferred midi file to %s', output_file)
        midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

        label_output_file = os.path.join(output_dir,
                                         filename_safe + '_label.mid')
        tf.logging.info('Writing label midi file to %s', label_output_file)
        midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

        # Also write a pianoroll showing acoustic model output vs labels.
        pianoroll_output_file = os.path.join(output_dir,
                                             filename_safe + '_pianoroll.png')
        tf.logging.info('Writing acoustic logit/label file to %s',
                        pianoroll_output_file)
        with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    predictions['frame_probs'], predictions['frame_labels']))

        # Update histogram and current scalar for metrics.
        with tf.Graph().as_default(), tf.Session().as_default():
            for k, v in predictions.items():
                if not k.startswith('metrics/'):
                    continue
                all_metrics[k].extend(v)
                histogram_name = 'histogram/' + k
                metric_summary = tf.summary.histogram(histogram_name,
                                                      tf.constant(
                                                          all_metrics[k],
                                                          name=histogram_name),
                                                      collections=[])
                summary_writer.add_summary(metric_summary.eval(),
                                           global_step=file_num)
                scalar_name = k
                metric_summary = tf.summary.scalar(scalar_name,
                                                   tf.constant(
                                                       np.mean(all_metrics[k]),
                                                       name=scalar_name),
                                                   collections=[])
                summary_writer.add_summary(metric_summary.eval(),
                                           global_step=file_num)
            summary_writer.flush()

        start_time = time.time()

    # Write final mean values for all metrics.
    with tf.Graph().as_default(), tf.Session().as_default():
        for k, v in all_metrics.items():
            final_scalar_name = 'final/' + k
            metric_summary = tf.summary.scalar(final_scalar_name,
                                               tf.constant(
                                                   np.mean(all_metrics[k]),
                                                   name=final_scalar_name),
                                               collections=[])
            summary_writer.add_summary(metric_summary.eval())
        summary_writer.flush()

    start_time = time.time()
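The "Make filenames UNIX-friendly" step in the example above keeps alphanumeric characters, replaces everything else with underscores, and prefixes a zero-padded counter so the output files sort in inference order. A minimal standalone sketch of that logic (the helper name make_filename_safe and the sample sequence id are ours for illustration, not part of the Magenta code):

def make_filename_safe(sequence_id, file_num, max_len=200):
    # Replace every non-alphanumeric character with '_', strip trailing
    # whitespace, truncate, and prefix a zero-padded file counter.
    chars = [c if c.isalnum() else '_' for c in sequence_id.decode('utf-8')]
    safe = ''.join(chars).rstrip()
    return '{:04d}_{}'.format(file_num, safe[:max_len])

# Hypothetical usage:
# make_filename_safe(b'maestro-v1/2017/track 01.wav', 3)
# -> '0003_maestro_v1_2017_track_01_wav'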
Example #24
0
File: infer.py Project: tensorflow/magenta
def model_inference(model_fn,
                    model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    master,
                    preprocess_examples,
                    write_summary_every_step=True):
  """Runs inference for the given examples."""
  tf.logging.info('model_dir=%s', model_dir)
  tf.logging.info('checkpoint_path=%s', checkpoint_path)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('output_dir=%s', output_dir)

  estimator = train_util.create_estimator(
      model_fn, model_dir, hparams, master=master)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    dataset = data.provide_batch(
        examples=examples_path,
        preprocess_examples=preprocess_examples,
        hparams=hparams,
        is_training=False)

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()

    if write_summary_every_step:
      global_step = tf.train.get_or_create_global_step()
      global_step_increment = global_step.assign_add(1)
    else:
      global_step = tf.constant(
          estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
      global_step_increment = global_step

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      infer_times = []
      num_frames = []

      sess.run(iterator.initializer)
      while True:
        try:
          record = sess.run(next_record)
        except tf.errors.OutOfRangeError:
          break

        def input_fn(params):
          del params
          return tf.data.Dataset.from_tensors(record)

        start_time = time.time()

        # TODO(fjord): This is a hack that allows us to keep using our existing
        # infer/scoring code with a tf.Estimator model. Ideally, we should
        # move things around so that we can use estimator.evaluate, which will
        # also be more efficient because it won't have to restore the checkpoint
        # for every example.
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1

        input_features = record[0]
        input_labels = record[1]

        filename = input_features.sequence_id[0]
        note_sequence = music_pb2.NoteSequence.FromString(
            input_labels.note_sequence[0])
        labels = input_labels.labels[0]
        frame_probs = prediction_list[0]['frame_probs'][0]
        frame_predictions = prediction_list[0]['frame_predictions'][0]
        onset_predictions = prediction_list[0]['onset_predictions'][0]
        velocity_values = prediction_list[0]['velocity_values'][0]
        offset_predictions = prediction_list[0]['offset_predictions'][0]

        if not FLAGS.require_onset:
          onset_predictions = None

        if not FLAGS.use_offset:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_predictions.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_predictions.shape[0],
            frame_predictions.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filename)

        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        sequence_label = sequences_lib.adjust_notesequence_times(
            note_sequence, shift_notesequence)[0]
        infer_util.score_sequence(
            sess,
            global_step_increment,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequence_label,
            sequence_id=filename)

        if write_summary_every_step:
          # Make filenames UNIX-friendly.
          filename_safe = filename.decode('utf-8').replace('/', '_').replace(
              ':', '.')
          output_file = os.path.join(output_dir, filename_safe + '.mid')
          tf.logging.info('Writing inferred midi file to %s', output_file)
          midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

          label_output_file = os.path.join(output_dir,
                                           filename_safe + '_label.mid')
          tf.logging.info('Writing label midi file to %s', label_output_file)
          midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

          # Also write a pianoroll showing acoustic model output vs labels.
          pianoroll_output_file = os.path.join(output_dir,
                                               filename_safe + '_pianoroll.png')
          tf.logging.info('Writing acoustic logit/label file to %s',
                          pianoroll_output_file)
          with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    frame_probs,
                    sequence_prediction,
                    labels,
                    overlap=True,
                    frames_per_second=data.hparams_frames_per_second(hparams)))

          summary = sess.run(summary_op)
          summary_writer.add_summary(summary, sess.run(global_step))
          summary_writer.flush()

      if not write_summary_every_step:
        # Only write the summary variables for the final step.
        summary = sess.run(summary_op)
        summary_writer.add_summary(summary, sess.run(global_step))
        summary_writer.flush()
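In Example #24, sequences_lib.pianoroll_to_note_sequence turns the thresholded frame/onset/offset matrices into a NoteSequence. Below is a rough conceptual sketch of the core frame-to-interval step only, not the library implementation; the helper name and the piano-range default of MIDI pitch 21 are assumptions made here for illustration. The real call additionally uses onset and offset predictions and velocity values to split and scale notes, which this sketch omits.

import numpy as np

def pianoroll_to_intervals(frames, frames_per_second, min_midi_pitch=21):
    """Scans each pitch column of a boolean [num_frames, num_pitches] matrix
    and emits (pitch, start_sec, end_sec) for every run of active frames."""
    notes = []
    num_frames, num_pitches = frames.shape
    for p in range(num_pitches):
        active = np.flatnonzero(frames[:, p])
        if active.size == 0:
            continue
        start = prev = active[0]
        for f in active[1:]:
            if f != prev + 1:  # a gap in the active frames ends the note
                notes.append((min_midi_pitch + p,
                              start / frames_per_second,
                              (prev + 1) / frames_per_second))
                start = f
            prev = f
        notes.append((min_midi_pitch + p,
                      start / frames_per_second,
                      (prev + 1) / frames_per_second))
    return notes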
Example #25
def run_with_flags(melody_rnn_sequence_generator):
    """Generates melodies and saves them as MIDI files.

    Uses the options specified by the flags defined in this module. Intended to
    be called from the main function of one of the melody generator modules.

    Args:
      melody_rnn_sequence_generator: A MelodyRnnSequenceGenerator object
          specific to your model.
    """
    if not FLAGS.output_dir:
        tf.logging.fatal('--output_dir required')
        return

    FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
    if FLAGS.primer_midi:
        FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)

    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    primer_sequence = None
    qpm = FLAGS.qpm if FLAGS.qpm else constants.DEFAULT_QUARTERS_PER_MINUTE
    if FLAGS.primer_melody:
        primer_melody = melodies_lib.MonophonicMelody()
        primer_melody.from_event_list(ast.literal_eval(FLAGS.primer_melody))
        primer_sequence = primer_melody.to_sequence(qpm=qpm)
    elif FLAGS.primer_midi:
        primer_sequence = midi_io.midi_file_to_sequence_proto(
            FLAGS.primer_midi)
        if primer_sequence.tempos and primer_sequence.tempos[0].qpm:
            qpm = primer_sequence.tempos[0].qpm

    # Derive the total number of seconds to generate based on the QPM of the
    # priming sequence and the num_steps flag.
    total_seconds = _steps_to_seconds(FLAGS.num_steps, qpm)

    # Specify start/stop time for generation based on starting generation at the
    # end of the priming sequence and continuing until the sequence is num_steps
    # long.
    generate_request = generator_pb2.GenerateSequenceRequest()
    if primer_sequence:
        generate_request.input_sequence.CopyFrom(primer_sequence)
        generate_section = (
            generate_request.generator_options.generate_sections.add())
        # Set the start time to begin on the next step after the last note ends.
        notes_by_end_time = sorted(primer_sequence.notes,
                                   key=lambda n: n.end_time)
        last_end_time = (notes_by_end_time[-1].end_time
                         if notes_by_end_time else 0)
        generate_section.start_time_seconds = last_end_time + _steps_to_seconds(
            1, qpm)
        generate_section.end_time_seconds = total_seconds

        if generate_section.start_time_seconds >= generate_section.end_time_seconds:
            tf.logging.fatal(
                'Priming sequence is longer than the total number of steps '
                'requested: Priming sequence length: %s, Generation length '
                'requested: %s', generate_section.start_time_seconds,
                total_seconds)
            return
    else:
        generate_section = (
            generate_request.generator_options.generate_sections.add())
        generate_section.start_time_seconds = 0
        generate_section.end_time_seconds = total_seconds
        generate_request.input_sequence.tempos.add().qpm = qpm
    tf.logging.debug('generate_request: %s', generate_request)

    # Make the generate request num_outputs times and save the output as midi
    # files.
    date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
    digits = len(str(FLAGS.num_outputs))
    for i in range(FLAGS.num_outputs):
        generate_response = melody_rnn_sequence_generator.generate(
            generate_request)

        midi_filename = '%s_%s.mid' % (date_and_time, str(i + 1).zfill(digits))
        midi_path = os.path.join(FLAGS.output_dir, midi_filename)
        midi_io.sequence_proto_to_midi_file(
            generate_response.generated_sequence, midi_path)

    tf.logging.info('Wrote %d MIDI files to %s', FLAGS.num_outputs,
                    FLAGS.output_dir)
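_steps_to_seconds is referenced in Example #25 but not shown in this listing. A plausible sketch of it, assuming the conventional 4 steps per quarter note used by the melody models (the constant name and value are assumptions here):

DEFAULT_STEPS_PER_QUARTER = 4  # assumption: 16th-note steps

def _steps_to_seconds(steps, qpm):
    # At qpm quarter notes per minute each quarter lasts 60/qpm seconds,
    # and each quarter is subdivided into DEFAULT_STEPS_PER_QUARTER steps.
    return steps * 60.0 / qpm / DEFAULT_STEPS_PER_QUARTER

# e.g. 128 steps at 120 QPM -> 128 * 60 / 120 / 4 = 16.0 seconds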