def initialize_session(acoustic_checkpoint, hparams):
  """Initializes a transcription session."""
  with tf.Graph().as_default():
    examples = tf.placeholder(tf.string, [None])

    hparams.batch_size = 1

    batch, iterator = data.provide_batch(
        batch_size=1,
        examples=examples,
        hparams=hparams,
        is_training=False,
        truncated_length=0)

    model.get_model(batch, hparams, is_training=False)

    session = tf.Session()
    saver = tf.train.Saver()
    saver.restore(session, acoustic_checkpoint)

    onset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'onsets/onset_probs_flat:0')
    frame_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'frame_probs_flat:0')
    velocity_values_flat = tf.get_default_graph().get_tensor_by_name(
        'velocity/velocity_values_flat:0')

    return TranscriptionSession(
        session=session,
        examples=examples,
        iterator=iterator,
        onset_probs_flat=onset_probs_flat,
        frame_probs_flat=frame_probs_flat,
        velocity_values_flat=velocity_values_flat,
        hparams=hparams)
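A minimal usage sketch for the session returned above (an illustration only: it assumes `acoustic_checkpoint` and `hparams` are set up as in the later examples, and that `serialized_example` is a serialized tf.train.Example containing audio, e.g. produced with audio_label_data_utils.process_record as shown further down this page):

transcription_session = initialize_session(acoustic_checkpoint, hparams)

# Feed one serialized tf.train.Example into the input pipeline.
transcription_session.session.run(
    transcription_session.iterator.initializer,
    feed_dict={transcription_session.examples: [serialized_example]})

# Fetch the flattened per-frame probabilities and velocities for that example.
frame_probs, onset_probs, velocity_values = transcription_session.session.run([
    transcription_session.frame_probs_flat,
    transcription_session.onset_probs_flat,
    transcription_session.velocity_values_flat,
])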
Example #2
def initialize_session(acoustic_checkpoint, hparams):
    """Initializes a transcription session."""
    with tf.Graph().as_default():
        examples = tf.placeholder(tf.string, [None])

        batch, iterator = data.provide_batch(batch_size=1,
                                             examples=examples,
                                             hparams=hparams,
                                             is_training=False,
                                             truncated_length=0)

        model.get_model(batch, hparams, is_training=False)

        session = tf.Session()
        saver = tf.train.Saver()
        saver.restore(session, acoustic_checkpoint)

        onset_probs_flat = tf.get_default_graph().get_tensor_by_name(
            'onsets/onset_probs_flat:0')
        frame_probs_flat = tf.get_default_graph().get_tensor_by_name(
            'frame_probs_flat:0')
        velocity_values_flat = tf.get_default_graph().get_tensor_by_name(
            'velocity/velocity_values_flat:0')

        return TranscriptionSession(session=session,
                                    examples=examples,
                                    iterator=iterator,
                                    onset_probs_flat=onset_probs_flat,
                                    frame_probs_flat=frame_probs_flat,
                                    velocity_values_flat=velocity_values_flat,
                                    hparams=hparams)
Example #3
    def __init__(self, path='onsets-frames'):
        """Load the Onset-Frames Model (arXiv:1710.11153 [cs.SD]) and pretrained weights from the path using tensorflow and magenta.

        Args:
            path (str, optional): The path the model weights. Defaults to 'onsets-frames'.
        """
        tf.disable_eager_execution()
        tf.disable_v2_behavior()
        tf.logging.set_verbosity(tf.logging.ERROR)

        self.config = configs.CONFIG_MAP['onsets_frames']
        self.hparams = self.config.hparams
        self.hparams.use_cudnn = False
        self.hparams.batch_size = 1
        self.checkpoint_dir = path

        self.examples = tf.placeholder(tf.string, [None])

        self.dataset = data.provide_batch(examples=self.examples,
                                          preprocess_examples=True,
                                          params=self.hparams,
                                          is_training=False,
                                          shuffle_examples=False,
                                          skip_n_initial_records=0)

        self.estimator = train_util.create_estimator(self.config.model_fn,
                                                     self.checkpoint_dir,
                                                     self.hparams)

        self.iterator = tf.data.make_initializable_iterator(self.dataset)
        self.next_record = self.iterator.get_next()
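The class above only builds the input pipeline and estimator; the prediction step is not part of this snippet. A sketch of how its attributes might be wired into an inference call, following the same pattern as the standalone transcribe() example further down this page (the class name `OnsetsFramesTranscriber`, the variable `serialized_example`, and the extra imports `infer_util` and `music_pb2` are assumptions, not part of the original):

transcriber = OnsetsFramesTranscriber(path='onsets-frames')  # hypothetical class name

sess = tf.Session()
sess.run([tf.initializers.global_variables(),
          tf.initializers.local_variables()])

# `serialized_example` is assumed to be a serialized tf.train.Example with audio.
sess.run(transcriber.iterator.initializer,
         {transcriber.examples: [serialized_example]})

def transcription_data(params):
    del params
    return tf.data.Dataset.from_tensors(sess.run(transcriber.next_record))

input_fn = infer_util.labels_to_features_wrapper(transcription_data)
prediction_list = list(
    transcriber.estimator.predict(input_fn, yield_single_examples=False))
sequence_prediction = music_pb2.NoteSequence.FromString(
    prediction_list[0]['sequence_predictions'][0])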
Example #4
def _get_data(examples_path, hparams, is_training):
  """Gets transcription data."""
  hparams_dict = hparams.values()
  return data.provide_batch(
      hparams.batch_size,
      examples=examples_path,
      hparams=hparams,
      truncated_length=hparams_dict.get('truncated_length', None),
      is_training=is_training)
def _get_data(examples_path, hparams, is_training):
    """Gets transcription data."""
    hparams_dict = hparams.values()
    batch, _ = data.provide_batch(hparams.batch_size,
                                  examples=examples_path,
                                  hparams=hparams,
                                  truncated_length=hparams_dict.get(
                                      'truncated_length', None),
                                  is_training=is_training)
    return batch
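For reference, this tuple-style provide_batch returns a batch object with named tensors; a minimal sketch of consuming it (field names are taken from the test snippets further down this page, and the TFRecord path is a placeholder):

# Sketch: consuming the batch returned above inside a session.
batch = _get_data('train_examples.tfrecord', hparams, is_training=False)  # placeholder path
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    spec, labels, lengths = sess.run([batch.spec, batch.labels, batch.lengths])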
Example #6
def _get_data(examples_path, hparams, is_training):
    """Gets transcription data."""
    hparams_dict = hparams.values()
    return data.provide_batch(hparams.batch_size,
                              examples=examples_path,
                              hparams=hparams,
                              truncated_length=hparams_dict.get(
                                  'truncated_length', None),
                              is_training=is_training)
Example #7
def _get_data(examples_path, hparams, is_training):
  """Gets transcription data."""
  hparams_dict = hparams.values()
  batch, _ = data.provide_batch(
      hparams.batch_size,
      examples=examples_path,
      hparams=hparams,
      truncated_length=hparams_dict.get('truncated_length', None),
      is_training=is_training,
      include_note_sequences=hparams_dict.get('include_events', False))
  return batch
Example #9
def _get_data(examples,
              preprocess_examples,
              params,
              is_training,
              semisupervised_configs=None):
    """Gets transcription data."""
    return data.provide_batch(examples=examples,
                              preprocess_examples=preprocess_examples,
                              hparams=params,
                              is_training=is_training,
                              semisupervised_configs=semisupervised_configs)
Example #10
    def _ValidateProvideBatch(self,
                              examples,
                              truncated_length,
                              batch_size,
                              expected_inputs,
                              feed_dict=None):
        """Tests for correctness of batches."""
        hparams = copy.deepcopy(configs.DEFAULT_HPARAMS)
        hparams.batch_size = batch_size
        hparams.truncated_length_secs = (
            truncated_length / data.hparams_frames_per_second(hparams))

        with self.test_session() as sess:
            dataset = data.provide_batch(examples=examples,
                                         preprocess_examples=True,
                                         params=hparams,
                                         is_training=False,
                                         shuffle_examples=False,
                                         skip_n_initial_records=0)
            iterator = dataset.make_initializable_iterator()
            next_record = iterator.get_next()
            sess.run([
                tf.initializers.local_variables(),
                tf.initializers.global_variables(), iterator.initializer
            ],
                     feed_dict=feed_dict)
            for i in range(0, len(expected_inputs), batch_size):
                # Wait to ensure example is pre-processed.
                time.sleep(0.1)
                features, labels = sess.run(next_record)
                inputs = [
                    features.spec, labels.labels, features.length,
                    features.sequence_id
                ]
                max_length = np.max(inputs[2])
                for j in range(batch_size):
                    # Add batch padding if needed.
                    input_length = expected_inputs[i + j][2]
                    if input_length < max_length:
                        expected_inputs[i + j] = list(expected_inputs[i + j])
                        pad_amt = max_length - input_length
                        expected_inputs[i + j][0] = np.pad(
                            expected_inputs[i + j][0], [(0, pad_amt), (0, 0)],
                            'constant')
                        expected_inputs[i + j][1] = np.pad(
                            expected_inputs[i + j][1], [(0, pad_amt), (0, 0)],
                            'constant')
                    for exp_input, input_ in zip(expected_inputs[i + j],
                                                 inputs):
                        self.assertAllEqual(np.squeeze(exp_input),
                                            np.squeeze(input_[j]))

            with self.assertRaisesOpError('End of sequence'):
                _ = sess.run(next_record)
Example #11
def _get_data(examples, preprocess_examples, params,
              is_training, shuffle_examples=None, skip_n_initial_records=0,
              semisupervised_configs=None):
  """Gets transcription data."""
  return data.provide_batch(
      examples=examples,
      preprocess_examples=preprocess_examples,
      hparams=params,
      is_training=is_training,
      semisupervised_configs=semisupervised_configs,
      shuffle_examples=shuffle_examples,
      skip_n_initial_records=skip_n_initial_records)
Example #12
    def _ValidateProvideBatch(self,
                              examples,
                              truncated_length,
                              batch_size,
                              expected_inputs,
                              crop_training_sequence_to_notes=False):
        """Tests for correctness of batches."""
        hparams = copy.deepcopy(constants.DEFAULT_HPARAMS)
        hparams.crop_training_sequence_to_notes = crop_training_sequence_to_notes

        with self.test_session() as sess:
            batch, _ = data.provide_batch(batch_size=batch_size,
                                          examples=examples,
                                          hparams=hparams,
                                          truncated_length=truncated_length,
                                          is_training=False)
            sess.run(tf.local_variables_initializer())
            input_tensors = [
                batch.spec, batch.labels, batch.lengths, batch.filenames,
                batch.max_length
            ]
            self.assertEqual(
                len(expected_inputs) // batch_size, batch.num_batches)
            for i in range(0, batch.num_batches * batch_size, batch_size):
                # Wait to ensure example is pre-processed.
                time.sleep(0.1)
                inputs = sess.run(input_tensors)
                max_length = np.max(inputs[2])
                self.assertEqual(inputs[4], max_length)
                inputs = inputs[0:-1]
                for j in range(batch_size):
                    # Add batch padding if needed.
                    input_length = expected_inputs[i + j][2]
                    if input_length < max_length:
                        expected_inputs[i + j] = list(expected_inputs[i + j])
                        pad_amt = max_length - input_length
                        expected_inputs[i + j][0] = np.pad(
                            expected_inputs[i + j][0], [(0, pad_amt), (0, 0)],
                            'constant')
                        expected_inputs[i + j][1] = np.pad(
                            expected_inputs[i + j][1], [(0, pad_amt), (0, 0)],
                            'constant')
                    for exp_input, input_ in zip(expected_inputs[i + j],
                                                 inputs):
                        self.assertAllEqual(np.squeeze(exp_input),
                                            np.squeeze(input_[j]))

            with self.assertRaisesOpError('End of sequence'):
                _ = sess.run(input_tensors)
Example #13
def _get_data(examples,
              preprocess_examples,
              params,
              is_training,
              shuffle_examples=None,
              skip_n_initial_records=0,
              semisupervised_configs=None):
    """Gets transcription data."""
    return data.provide_batch(examples=examples,
                              preprocess_examples=preprocess_examples,
                              hparams=params,
                              is_training=is_training,
                              semisupervised_configs=semisupervised_configs,
                              shuffle_examples=shuffle_examples,
                              skip_n_initial_records=skip_n_initial_records)
Example #14
  def _ValidateProvideBatch(self,
                            examples,
                            truncated_length,
                            batch_size,
                            expected_inputs,
                            feed_dict=None):
    """Tests for correctness of batches."""
    hparams = copy.deepcopy(configs.DEFAULT_HPARAMS)
    hparams.batch_size = batch_size
    hparams.truncated_length_secs = (
        truncated_length / data.hparams_frames_per_second(hparams))

    with self.test_session() as sess:
      dataset = data.provide_batch(
          examples=examples,
          preprocess_examples=True,
          hparams=hparams,
          is_training=False)
      iterator = dataset.make_initializable_iterator()
      next_record = iterator.get_next()
      sess.run([
          tf.initializers.local_variables(),
          tf.initializers.global_variables(),
          iterator.initializer
      ], feed_dict=feed_dict)
      for i in range(0, len(expected_inputs), batch_size):
        # Wait to ensure example is pre-processed.
        time.sleep(0.1)
        features, labels = sess.run(next_record)
        inputs = [
            features.spec, labels.labels, features.length, features.sequence_id]
        max_length = np.max(inputs[2])
        for j in range(batch_size):
          # Add batch padding if needed.
          input_length = expected_inputs[i + j][2]
          if input_length < max_length:
            expected_inputs[i + j] = list(expected_inputs[i + j])
            pad_amt = max_length - input_length
            expected_inputs[i + j][0] = np.pad(
                expected_inputs[i + j][0], [(0, pad_amt), (0, 0)], 'constant')
            expected_inputs[i + j][1] = np.pad(
                expected_inputs[i + j][1],
                [(0, pad_amt), (0, 0)], 'constant')
          for exp_input, input_ in zip(expected_inputs[i + j], inputs):
            self.assertAllEqual(np.squeeze(exp_input), np.squeeze(input_[j]))

      with self.assertRaisesOpError('End of sequence'):
        _ = sess.run(next_record)
Example #15
    def validateProvideBatch(self, examples_path, truncated_length, batch_size,
                             expected_inputs):
        """Tests for correctness of batches."""
        hparams = copy.deepcopy(constants.DEFAULT_HPARAMS)

        with self.test_session() as sess:
            batch = data.provide_batch(batch_size=batch_size,
                                       examples_path=examples_path,
                                       hparams=hparams,
                                       truncated_length=truncated_length,
                                       is_training=False,
                                       batch_threads=1)
            sess.run(tf.local_variables_initializer())
            input_tensors = [
                batch.spec, batch.labels, batch.lengths, batch.filenames
            ]
            self.assertEqual(
                len(expected_inputs) // batch_size, batch.num_batches)
            with tf.contrib.slim.queues.QueueRunners(sess):
                for i in range(0, batch.num_batches * batch_size, batch_size):
                    # Wait to ensure example is pre-processed.
                    time.sleep(0.1)
                    inputs = sess.run(input_tensors)
                    max_length = np.max(inputs[2])
                    for j in range(batch_size):
                        # Add batch padding if needed.
                        input_length = expected_inputs[i + j][2]
                        if input_length < max_length:
                            expected_inputs[i + j] = list(expected_inputs[i +
                                                                          j])
                            pad_amt = max_length - input_length
                            expected_inputs[i + j][0] = np.pad(
                                expected_inputs[i + j][0],
                                [(0, pad_amt), (0, 0)], 'constant')
                            expected_inputs[i + j][1] = np.pad(
                                expected_inputs[i + j][1],
                                [(0, pad_amt), (0, 0)], 'constant')
                        for exp_input, input_ in zip(expected_inputs[i + j],
                                                     inputs):
                            self.assertAllEqual(np.squeeze(exp_input),
                                                np.squeeze(input_[j]))

                with self.assertRaisesOpError(
                        'is closed and has insufficient elements '
                        '\\(requested %d, current size %d\\)' %
                    (batch_size, len(expected_inputs) % batch_size)):
                    _ = sess.run(input_tensors)
Example #16
  def _ValidateProvideBatch(self,
                            examples,
                            truncated_length,
                            batch_size,
                            expected_inputs,
                            crop_training_sequence_to_notes=False):
    """Tests for correctness of batches."""
    hparams = copy.deepcopy(constants.DEFAULT_HPARAMS)
    hparams.crop_training_sequence_to_notes = crop_training_sequence_to_notes

    with self.test_session() as sess:
      batch, _ = data.provide_batch(
          batch_size=batch_size,
          examples=examples,
          hparams=hparams,
          truncated_length=truncated_length,
          is_training=False)
      sess.run(tf.local_variables_initializer())
      input_tensors = [
          batch.spec, batch.labels, batch.lengths, batch.filenames,
          batch.max_length
      ]
      self.assertEqual(len(expected_inputs) // batch_size, batch.num_batches)
      for i in range(0, batch.num_batches * batch_size, batch_size):
        # Wait to ensure example is pre-processed.
        time.sleep(0.1)
        inputs = sess.run(input_tensors)
        max_length = np.max(inputs[2])
        self.assertEqual(inputs[4], max_length)
        inputs = inputs[0:-1]
        for j in range(batch_size):
          # Add batch padding if needed.
          input_length = expected_inputs[i + j][2]
          if input_length < max_length:
            expected_inputs[i + j] = list(expected_inputs[i + j])
            pad_amt = max_length - input_length
            expected_inputs[i + j][0] = np.pad(
                expected_inputs[i + j][0], [(0, pad_amt), (0, 0)], 'constant')
            expected_inputs[i + j][1] = np.pad(
                expected_inputs[i + j][1],
                [(0, pad_amt), (0, 0)], 'constant')
          for exp_input, input_ in zip(expected_inputs[i + j], inputs):
            self.assertAllEqual(np.squeeze(exp_input), np.squeeze(input_[j]))

      with self.assertRaisesOpError('End of sequence'):
        _ = sess.run(input_tensors)
def transcribe(audio, sr, cuda=False):
    """
    Google sucks and want to use audio path (raw wav) instead of decoded
    samples loosing in decoupling between file format and DSP

    input audio and sample rate, output mat like asmd with (pitch, ons, offs, velocity)
    """

    # Simple hack: temporarily monkey-patch wav_data_to_samples so that audio
    # data that has already been loaded can be accepted, keeping a reasonable
    # interface (and decoupling I/O from processing). See the sketch after
    # this function for what the replacement shim might look like.
    original_google_sucks = audio_io.wav_data_to_samples
    audio_io.wav_data_to_samples = google_sucks
    audio = np.array(audio)
    config = configs.CONFIG_MAP['onsets_frames']
    hparams = config.hparams
    hparams.use_cudnn = cuda
    hparams.batch_size = 1
    examples = tf.placeholder(tf.string, [None])

    dataset = data.provide_batch(examples=examples,
                                 preprocess_examples=True,
                                 params=hparams,
                                 is_training=False,
                                 shuffle_examples=False,
                                 skip_n_initial_records=0)

    estimator = train_util.create_estimator(config.model_fn, CHECKPOINT_DIR,
                                            hparams)

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()

    example_list = list(
        audio_label_data_utils.process_record(wav_data=audio,
                                              sample_rate=sr,
                                              ns=music_pb2.NoteSequence(),
                                              example_id="fakeid",
                                              min_length=0,
                                              max_length=-1,
                                              allow_empty_notesequence=True,
                                              load_audio_with_librosa=False))
    assert len(example_list) == 1
    to_process = [example_list[0].SerializeToString()]

    sess = tf.Session()

    sess.run([
        tf.initializers.global_variables(),
        tf.initializers.local_variables()
    ])

    sess.run(iterator.initializer, {examples: to_process})

    def transcription_data(params):
        del params
        return tf.data.Dataset.from_tensors(sess.run(next_record))

    # Put back the original function (it still writes and reloads a temporary
    # file internally, though).
    audio_io.wav_data_to_samples = original_google_sucks
    input_fn = infer_util.labels_to_features_wrapper(transcription_data)

    prediction_list = list(
        estimator.predict(input_fn, yield_single_examples=False))

    assert len(prediction_list) == 1

    notes = music_pb2.NoteSequence.FromString(
        prediction_list[0]['sequence_predictions'][0]).notes

    out = np.empty((len(notes), 4))
    for i, note in enumerate(notes):
        out[i] = [note.pitch, note.start_time, note.end_time, note.velocity]
    return out
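The `google_sucks` shim that temporarily replaces `audio_io.wav_data_to_samples` above is defined elsewhere in the author's code and is not shown here. A heavily hedged guess at its shape, assuming it simply skips the WAV decoding and passes the already-decoded samples through (it must match the `wav_data_to_samples(wav_data, sample_rate)` signature it replaces):

def google_sucks(wav_data, sample_rate):
    """Hypothetical stand-in: return the samples the caller already decoded.

    This is only an illustration of the monkey-patch idea; the real helper in
    the author's code may differ.
    """
    del sample_rate  # assumed to already match the desired rate
    return wav_data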
Example #18
def _get_data(examples, preprocess_examples, params, is_training):
    """Gets transcription data."""
    return data.provide_batch(examples=examples,
                              preprocess_examples=preprocess_examples,
                              hparams=params,
                              is_training=is_training)
Example #19
# Set the checkpoint path
CHECKPOINT_DIR = '../train'

# Set the hyperparameters
config = configs.CONFIG_MAP['onsets_frames']
hparams = config.hparams
hparams.use_cudnn = False
hparams.batch_size = 1

# Set up the placeholder
examples = tf.placeholder(tf.string, [None])

# Create the batch
dataset = data.provide_batch(examples=examples, preprocess_examples=True,
    params=hparams, is_training=False, shuffle_examples=False,
    skip_n_initial_records=0)

# Create the estimator
estimator = train_util.create_estimator(
    config.model_fn, CHECKPOINT_DIR, hparams)

# Create the iterator
iterator = dataset.make_initializable_iterator()
next_record = iterator.get_next()

# Create the session
sess = tf.Session()

# Initialize the session
sess.run([tf.initializers.global_variables(), tf.initializers.local_variables()])
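The script above stops after initializing the session; the per-file transcription step is not part of this snippet. A sketch of the remaining wiring, reusing the inference pattern from the other examples on this page (it additionally assumes `audio_label_data_utils`, `infer_util`, `midi_io`, and `music_pb2` are imported, and the WAV/MIDI paths are placeholders):

# Read one WAV file and wrap it in a tf.train.Example.
with open('input.wav', 'rb') as f:  # placeholder path
    wav_data = f.read()
example_list = list(
    audio_label_data_utils.process_record(wav_data=wav_data,
                                          ns=music_pb2.NoteSequence(),
                                          example_id='input.wav',
                                          min_length=0,
                                          max_length=-1,
                                          allow_empty_notesequence=True))
to_process = [example_list[0].SerializeToString()]

# Feed the example through the dataset defined above.
sess.run(iterator.initializer, {examples: to_process})

def transcription_data(params):
    del params
    return tf.data.Dataset.from_tensors(sess.run(next_record))

input_fn = infer_util.labels_to_features_wrapper(transcription_data)
prediction_list = list(
    estimator.predict(input_fn, yield_single_examples=False))

# Convert the predicted NoteSequence to a MIDI file.
sequence_prediction = music_pb2.NoteSequence.FromString(
    prediction_list[0]['sequence_predictions'][0])
midi_io.sequence_proto_to_midi_file(sequence_prediction, 'output.mid')  # placeholder path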
def model_inference(acoustic_checkpoint, hparams, examples_path, run_dir):
  """Runs inference for the given examples."""
  tf.logging.info('acoustic_checkpoint=%s', acoustic_checkpoint)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('run_dir=%s', run_dir)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    # Build the acoustic model within an 'acoustic' scope to isolate its
    # variables from the other models.
    with tf.variable_scope('acoustic'):
      truncated_length = 0
      if FLAGS.max_seconds_per_sequence:
        truncated_length = int(
            math.ceil((FLAGS.max_seconds_per_sequence *
                       data.hparams_frames_per_second(hparams))))
      acoustic_data_provider, _ = data.provide_batch(
          batch_size=1,
          examples=examples_path,
          hparams=hparams,
          is_training=False,
          truncated_length=truncated_length,
          include_note_sequences=True)

      _, _, data_labels, _, _ = model.get_model(
          acoustic_data_provider, hparams, is_training=False)

    # The checkpoints won't have the new scopes.
    acoustic_variables = {
        re.sub(r'^acoustic/', '', var.op.name): var
        for var in slim.get_variables(scope='acoustic/')
    }
    acoustic_restore = tf.train.Saver(acoustic_variables)

    onset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/onsets/onset_probs_flat:0')
    frame_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/frame_probs_flat:0')
    offset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/offsets/offset_probs_flat:0')
    velocity_values_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/velocity/velocity_values_flat:0')

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()
    global_step = tf.contrib.framework.get_or_create_global_step()
    global_step_increment = global_step.assign_add(1)

    # Use a custom init function to restore the acoustic and language models
    # from their separate checkpoints.
    def init_fn(unused_self, sess):
      acoustic_restore.restore(sess, acoustic_checkpoint)

    scaffold = tf.train.Scaffold(init_fn=init_fn)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, master=FLAGS.master)
    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
      tf.logging.info('running session')
      summary_writer = tf.summary.FileWriter(
          logdir=run_dir, graph=sess.graph)

      tf.logging.info('Inferring for %d batches',
                      acoustic_data_provider.num_batches)
      infer_times = []
      num_frames = []
      for unused_i in range(acoustic_data_provider.num_batches):
        start_time = time.time()
        (labels, filenames, note_sequences, frame_probs, onset_probs,
         offset_probs, velocity_values) = sess.run([
             data_labels,
             acoustic_data_provider.filenames,
             acoustic_data_provider.note_sequences,
             frame_probs_flat,
             onset_probs_flat,
             offset_probs_flat,
             velocity_values_flat,
         ])
        # We expect these all to be length 1 because batch size is 1.
        assert len(filenames) == len(note_sequences) == 1
        # These should be the same length and have been flattened.
        assert len(labels) == len(frame_probs) == len(onset_probs)

        frame_predictions = frame_probs > FLAGS.frame_threshold
        if FLAGS.require_onset:
          onset_predictions = onset_probs > FLAGS.onset_threshold
        else:
          onset_predictions = None

        if FLAGS.use_offset:
          offset_predictions = offset_probs > FLAGS.offset_threshold
        else:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_probs.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_probs.shape[0], frame_probs.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filenames[0])

        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        sequence_label = infer_util.score_sequence(
            sess,
            global_step_increment,
            summary_op,
            summary_writer,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequences_lib.adjust_notesequence_times(
                music_pb2.NoteSequence.FromString(note_sequences[0]),
                shift_notesequence)[0],
            sequence_id=filenames[0])

        # Make filenames UNIX-friendly.
        filename = filenames[0].decode('utf-8').replace('/', '_').replace(
            ':', '.')
        output_file = os.path.join(run_dir, filename + '.mid')
        tf.logging.info('Writing inferred midi file to %s', output_file)
        midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

        label_output_file = os.path.join(run_dir, filename + '_label.mid')
        tf.logging.info('Writing label midi file to %s', label_output_file)
        midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

        # Also write a pianoroll showing acoustic model output vs labels.
        pianoroll_output_file = os.path.join(run_dir,
                                             filename + '_pianoroll.png')
        tf.logging.info('Writing acoustic logit/label file to %s',
                        pianoroll_output_file)
        with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
          scipy.misc.imsave(
              f,
              infer_util.posterior_pianoroll_image(
                  frame_probs,
                  sequence_prediction,
                  labels,
                  overlap=True,
                  frames_per_second=data.hparams_frames_per_second(hparams)))

        summary_writer.flush()
Example #21
def model_inference(model_fn,
                    model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    master,
                    preprocess_examples,
                    write_summary_every_step=True):
  """Runs inference for the given examples."""
  tf.logging.info('model_dir=%s', model_dir)
  tf.logging.info('checkpoint_path=%s', checkpoint_path)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('output_dir=%s', output_dir)

  estimator = train_util.create_estimator(
      model_fn, model_dir, hparams, master=master)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    dataset = data.provide_batch(
        examples=examples_path,
        preprocess_examples=preprocess_examples,
        hparams=hparams,
        is_training=False)

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()

    if write_summary_every_step:
      global_step = tf.train.get_or_create_global_step()
      global_step_increment = global_step.assign_add(1)
    else:
      global_step = tf.constant(
          estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
      global_step_increment = global_step

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      infer_times = []
      num_frames = []

      sess.run(iterator.initializer)
      while True:
        try:
          record = sess.run(next_record)
        except tf.errors.OutOfRangeError:
          break

        def input_fn(params):
          del params
          return tf.data.Dataset.from_tensors(record)

        start_time = time.time()

        # TODO(fjord): This is a hack that allows us to keep using our existing
        # infer/scoring code with a tf.Estimator model. Ideally, we should
        # move things around so that we can use estimator.evaluate, which will
        # also be more efficient because it won't have to restore the checkpoint
        # for every example.
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1

        input_features = record[0]
        input_labels = record[1]

        filename = input_features.sequence_id[0]
        note_sequence = music_pb2.NoteSequence.FromString(
            input_labels.note_sequence[0])
        labels = input_labels.labels[0]
        frame_probs = prediction_list[0]['frame_probs'][0]
        frame_predictions = prediction_list[0]['frame_predictions'][0]
        onset_predictions = prediction_list[0]['onset_predictions'][0]
        velocity_values = prediction_list[0]['velocity_values'][0]
        offset_predictions = prediction_list[0]['offset_predictions'][0]

        if not FLAGS.require_onset:
          onset_predictions = None

        if not FLAGS.use_offset:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_predictions.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_predictions.shape[0],
            frame_predictions.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filename)

        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        sequence_label = sequences_lib.adjust_notesequence_times(
            note_sequence, shift_notesequence)[0]
        infer_util.score_sequence(
            sess,
            global_step_increment,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequence_label,
            sequence_id=filename)

        if write_summary_every_step:
          # Make filenames UNIX-friendly.
          filename_safe = filename.decode('utf-8').replace('/', '_').replace(
              ':', '.')
          output_file = os.path.join(output_dir, filename_safe + '.mid')
          tf.logging.info('Writing inferred midi file to %s', output_file)
          midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

          label_output_file = os.path.join(output_dir,
                                           filename_safe + '_label.mid')
          tf.logging.info('Writing label midi file to %s', label_output_file)
          midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

          # Also write a pianoroll showing acoustic model output vs labels.
          pianoroll_output_file = os.path.join(output_dir,
                                               filename_safe + '_pianoroll.png')
          tf.logging.info('Writing acoustic logit/label file to %s',
                          pianoroll_output_file)
          with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    frame_probs,
                    sequence_prediction,
                    labels,
                    overlap=True,
                    frames_per_second=data.hparams_frames_per_second(hparams)))

          summary = sess.run(summary_op)
          summary_writer.add_summary(summary, sess.run(global_step))
          summary_writer.flush()

      if not write_summary_every_step:
        # Only write the summary variables for the final step.
        summary = sess.run(summary_op)
        summary_writer.add_summary(summary, sess.run(global_step))
        summary_writer.flush()
Example #22
from magenta.models.onsets_frames_transcription import data
from magenta.models.onsets_frames_transcription import split_audio_and_label_data
from magenta.models.onsets_frames_transcription import train_util
from magenta.music import midi_io
from magenta.protobuf import music_pb2
from magenta.music import sequences_lib

## Define model and load checkpoint
## Only needs to be run once.

config = configs.CONFIG_MAP['onsets_frames']
hparams = config.hparams
hparams.use_cudnn = False
hparams.batch_size = 1

examples = tf.placeholder(tf.string, [None])

dataset = data.provide_batch(
    examples=examples,
    preprocess_examples=True,
    hparams=hparams,
    is_training=False)

# Change this to the downloaded checkpoint path.
CHECKPOINT_DIR = '/Users/junhoyeo/Desktop/magenta-school-song/maestro-v1.0.0'
estimator = train_util.create_estimator(
    config.model_fn, CHECKPOINT_DIR, hparams)

iterator = dataset.make_initializable_iterator()
next_record = iterator.get_next()
Example #23
CHECKPOINT_DIR = './train/train_50002'  ##todo

acoustic_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
print('acoustic_checkpoint=' + acoustic_checkpoint)
hparams = tf_utils.merge_hparams(constants.DEFAULT_HPARAMS,
                                 model.get_default_hparams())

with tf.Graph().as_default():
    examples = tf.placeholder(tf.string, [None])

    num_dims = constants.MIDI_PITCHES

    batch, iterator = data.provide_batch(batch_size=1,
                                         examples=examples,
                                         hparams=hparams,
                                         is_training=False,
                                         truncated_length=0)

    model.get_model(batch, hparams, is_training=False)

    session = tf.Session()
    saver = tf.train.Saver()
    saver.restore(session, acoustic_checkpoint)

    onset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'onsets/onset_probs_flat:0')
    frame_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'frame_probs_flat:0')

# file = open('./data/wav_format/xinjing.wav','w')   ##todo
def main(argv):
  tf.logging.set_verbosity(FLAGS.log)

  config = configs.CONFIG_MAP[FLAGS.config]
  hparams = config.hparams
  # For this script, default to not using cudnn.
  hparams.use_cudnn = False
  hparams.parse(FLAGS.hparams)
  hparams.batch_size = 1
  hparams.truncated_length_secs = 0

  with tf.Graph().as_default():
    examples = tf.placeholder(tf.string, [None])

    dataset = data.provide_batch(
        examples=examples,
        preprocess_examples=True,
        hparams=hparams,
        is_training=False)

    estimator = train_util.create_estimator(config.model_fn,
                                            os.path.expanduser(FLAGS.model_dir),
                                            hparams)

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()

    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      for filename in argv[1:]:
        tf.logging.info('Starting transcription for %s...', filename)

        # The reason we bounce between two Dataset objects is so we can use
        # the data processing functionality in data.py without having to
        # construct all the Example protos in memory ahead of time or create
        # a temporary tfrecord file.
        tf.logging.info('Processing file...')
        sess.run(iterator.initializer, {examples: [create_example(filename)]})

        def input_fn(params):
          del params
          return tf.data.Dataset.from_tensors(sess.run(next_record))

        tf.logging.info('Running inference...')
        checkpoint_path = None
        if FLAGS.checkpoint_path:
          checkpoint_path = os.path.expanduser(FLAGS.checkpoint_path)
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1

        sequence_prediction = transcribe_audio(prediction_list[0], hparams)

        midi_filename = filename + '.midi'
        midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

        tf.logging.info('Transcription written to %s.', midi_filename)
Example #25
def main(input, output):
    MAESTRO_CHECKPOINT_DIR = '/data/maestro/train'

    config = configs.CONFIG_MAP['onsets_frames']
    hparams = config.hparams
    hparams.use_cudnn = False
    hparams.batch_size = 1
    checkpoint_dir = MAESTRO_CHECKPOINT_DIR

    examples = tf.placeholder(tf.string, [None])

    dataset = data.provide_batch(
        examples=examples,
        preprocess_examples=True,
        params=hparams,
        is_training=False,
        shuffle_examples=False,
        skip_n_initial_records=0)

    estimator = train_util.create_estimator(
        config.model_fn, checkpoint_dir, hparams)

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()

    to_process = []
    def process(files):
        for fn in files:
            print('**\n\n', fn, '\n\n**')
            with open(fn, 'rb', buffering=0) as f:
                wav_data = f.read()
            example_list = list(
                audio_label_data_utils.process_record(
                wav_data=wav_data,
                ns=music_pb2.NoteSequence(),
                example_id=fn,
                min_length=0,
                max_length=-1,
                allow_empty_notesequence=True))
            assert len(example_list) == 1
            to_process.append(example_list[0].SerializeToString())
            print('Processing complete for', fn)

            sess = tf.Session()

            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            sess.run(iterator.initializer, {examples: to_process})

            def transcription_data(params):
                del params
                return tf.data.Dataset.from_tensors(sess.run(next_record))


            input_fn = infer_util.labels_to_features_wrapper(transcription_data)

            #@title Run inference
            prediction_list = list(
                estimator.predict(
                    input_fn,
                    yield_single_examples=False))
            assert len(prediction_list) == 1

            # Ignore warnings caused by pyfluidsynth
            import warnings
            warnings.filterwarnings("ignore", category=DeprecationWarning) 

            sequence_prediction = music_pb2.NoteSequence.FromString(
                prediction_list[0]['sequence_predictions'][0])

            pathname = fn.split('/').pop()
            print('**\n\n', pathname, '\n\n**')
            midi_filename = '{outputs}/{file}.mid'.format(outputs=output,file=pathname)
            midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)

    files = ['{inputs}/{file}'.format(inputs=input, file=file) for file in os.listdir(input) if file.split('.').pop() == 'wav']
    print('the files', files)
    process(files)
def main(argv):
    tf.logging.set_verbosity(FLAGS.log)

    config = configs.CONFIG_MAP[FLAGS.config]
    hparams = config.hparams
    # For this script, default to not using cudnn.
    hparams.use_cudnn = False
    hparams.parse(FLAGS.hparams)
    hparams.batch_size = 1
    hparams.truncated_length_secs = 0

    with tf.Graph().as_default():
        examples = tf.placeholder(tf.string, [None])

        dataset = data.provide_batch(examples=examples,
                                     preprocess_examples=True,
                                     hparams=hparams,
                                     is_training=False)

        estimator = train_util.create_estimator(
            config.model_fn, os.path.expanduser(FLAGS.model_dir), hparams)

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()

        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            for filename in argv[1:]:
                tf.logging.info('Starting transcription for %s...', filename)

                # The reason we bounce between two Dataset objects is so we can use
                # the data processing functionality in data.py without having to
                # construct all the Example protos in memory ahead of time or create
                # a temporary tfrecord file.
                tf.logging.info('Processing file...')
                sess.run(iterator.initializer,
                         {examples: [create_example(filename)]})

                def input_fn(params):
                    del params
                    return tf.data.Dataset.from_tensors(sess.run(next_record))

                tf.logging.info('Running inference...')
                checkpoint_path = None
                if FLAGS.checkpoint_path:
                    checkpoint_path = os.path.expanduser(FLAGS.checkpoint_path)
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                assert len(prediction_list) == 1

                sequence_prediction = transcribe_audio(prediction_list[0],
                                                       hparams)

                midi_filename = filename + '.midi'
                midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                                    midi_filename)

                tf.logging.info('Transcription written to %s.', midi_filename)
def model_inference(model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    write_summary_every_step=True):
    """Runs inference for the given examples."""
    tf.logging.info('model_dir=%s', model_dir)
    tf.logging.info('checkpoint_path=%s', checkpoint_path)
    tf.logging.info('examples_path=%s', examples_path)
    tf.logging.info('output_dir=%s', output_dir)

    estimator = train_util.create_estimator(model_dir, hparams)

    with tf.Graph().as_default():
        num_dims = constants.MIDI_PITCHES

        if FLAGS.max_seconds_per_sequence:
            truncated_length = int(
                math.ceil((FLAGS.max_seconds_per_sequence *
                           data.hparams_frames_per_second(hparams))))
        else:
            truncated_length = 0

        dataset = data.provide_batch(batch_size=1,
                                     examples=examples_path,
                                     hparams=hparams,
                                     is_training=False,
                                     truncated_length=truncated_length)

        # Define some metrics.
        (metrics_to_updates, metric_note_precision, metric_note_recall,
         metric_note_f1, metric_note_precision_with_offsets,
         metric_note_recall_with_offsets, metric_note_f1_with_offsets,
         metric_note_precision_with_offsets_velocity,
         metric_note_recall_with_offsets_velocity,
         metric_note_f1_with_offsets_velocity, metric_frame_labels,
         metric_frame_predictions) = infer_util.define_metrics(num_dims)

        summary_op = tf.summary.merge_all()

        if write_summary_every_step:
            global_step = tf.train.get_or_create_global_step()
            global_step_increment = global_step.assign_add(1)
        else:
            global_step = tf.constant(
                estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
            global_step_increment = global_step

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()
        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            infer_times = []
            num_frames = []

            sess.run(iterator.initializer)
            while True:
                try:
                    record = sess.run(next_record)
                except tf.errors.OutOfRangeError:
                    break

                def input_fn():
                    return tf.data.Dataset.from_tensors(record)

                start_time = time.time()

                # TODO(fjord): This is a hack that allows us to keep using our existing
                # infer/scoring code with a tf.Estimator model. Ideally, we should
                # move things around so that we can use estimator.evaluate, which will
                # also be more efficient because it won't have to restore the checkpoint
                # for every example.
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                assert len(prediction_list) == 1

                input_features = record[0]
                input_labels = record[1]

                filename = input_features.sequence_id[0]
                note_sequence = music_pb2.NoteSequence.FromString(
                    input_labels.note_sequence[0])
                labels = input_labels.labels[0]
                frame_probs = prediction_list[0]['frame_probs_flat']
                onset_probs = prediction_list[0]['onset_probs_flat']
                velocity_values = prediction_list[0]['velocity_values_flat']
                offset_probs = prediction_list[0]['offset_probs_flat']

                frame_predictions = frame_probs > FLAGS.frame_threshold
                if FLAGS.require_onset:
                    onset_predictions = onset_probs > FLAGS.onset_threshold
                else:
                    onset_predictions = None

                if FLAGS.use_offset:
                    offset_predictions = offset_probs > FLAGS.offset_threshold
                else:
                    offset_predictions = None

                sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
                    frame_predictions,
                    frames_per_second=data.hparams_frames_per_second(hparams),
                    min_duration_ms=0,
                    min_midi_pitch=constants.MIN_MIDI_PITCH,
                    onset_predictions=onset_predictions,
                    offset_predictions=offset_predictions,
                    velocity_values=velocity_values)

                end_time = time.time()
                infer_time = end_time - start_time
                infer_times.append(infer_time)
                num_frames.append(frame_probs.shape[0])
                tf.logging.info(
                    'Infer time %f, frames %d, frames/sec %f, running average %f',
                    infer_time, frame_probs.shape[0],
                    frame_probs.shape[0] / infer_time,
                    np.sum(num_frames) / np.sum(infer_times))

                tf.logging.info('Scoring sequence %s', filename)

                def shift_notesequence(ns_time):
                    return ns_time + hparams.backward_shift_amount_ms / 1000.

                sequence_label = sequences_lib.adjust_notesequence_times(
                    note_sequence, shift_notesequence)[0]
                infer_util.score_sequence(
                    sess,
                    global_step_increment,
                    metrics_to_updates,
                    metric_note_precision,
                    metric_note_recall,
                    metric_note_f1,
                    metric_note_precision_with_offsets,
                    metric_note_recall_with_offsets,
                    metric_note_f1_with_offsets,
                    metric_note_precision_with_offsets_velocity,
                    metric_note_recall_with_offsets_velocity,
                    metric_note_f1_with_offsets_velocity,
                    metric_frame_labels,
                    metric_frame_predictions,
                    frame_labels=labels,
                    sequence_prediction=sequence_prediction,
                    frames_per_second=data.hparams_frames_per_second(hparams),
                    sequence_label=sequence_label,
                    sequence_id=filename)

                if write_summary_every_step:
                    # Make filenames UNIX-friendly.
                    filename_safe = filename.decode('utf-8').replace(
                        '/', '_').replace(':', '.')
                    output_file = os.path.join(output_dir,
                                               filename_safe + '.mid')
                    tf.logging.info('Writing inferred midi file to %s',
                                    output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_prediction, output_file)

                    label_output_file = os.path.join(
                        output_dir, filename_safe + '_label.mid')
                    tf.logging.info('Writing label midi file to %s',
                                    label_output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_label, label_output_file)

                    # Also write a pianoroll showing acoustic model output vs labels.
                    pianoroll_output_file = os.path.join(
                        output_dir, filename_safe + '_pianoroll.png')
                    tf.logging.info('Writing acoustic logit/label file to %s',
                                    pianoroll_output_file)
                    with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
                        scipy.misc.imsave(
                            f,
                            infer_util.posterior_pianoroll_image(
                                frame_probs,
                                sequence_prediction,
                                labels,
                                overlap=True,
                                frames_per_second=data.hparams_frames_per_second(hparams)))

                    summary = sess.run(summary_op)
                    summary_writer.add_summary(summary, sess.run(global_step))
                    summary_writer.flush()

            if not write_summary_every_step:
                # Only write the summary variables for the final step.
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, sess.run(global_step))
                summary_writer.flush()