Example #1
0
def model_inference(model_fn,
                    model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    master,
                    preprocess_examples,
                    write_summary_every_step=True):
  """Runs inference for the given examples.

  Builds an Estimator via `train_util.create_estimator`, then iterates over
  every record produced by `data.provide_batch` and runs `estimator.predict`
  on each record on its own. The predicted pianoroll is converted into a
  NoteSequence and scored with the metrics from `infer_util.define_metrics`.

  Reads `FLAGS.require_onset` and `FLAGS.use_offset` to decide whether onset
  and offset predictions are used when building the NoteSequence.

  Args:
    model_fn: Model function passed to `train_util.create_estimator`.
    model_dir: Model directory passed to `train_util.create_estimator`.
    checkpoint_path: Checkpoint passed to `estimator.predict` for every record.
    hparams: Hyperparameters for the estimator and the input pipeline. Also
      used for the frame rate and the label time shift.
    examples_path: Examples passed to `data.provide_batch`.
    output_dir: Directory that receives the per-example `.mid`, `_label.mid`
      and `_pianoroll.png` files. These are only written when
      `write_summary_every_step` is True.
    summary_writer: Writer that receives the merged summaries.
    master: Passed to `train_util.create_estimator`.
    preprocess_examples: Passed to `data.provide_batch`.
    write_summary_every_step: If True, a global step created in the inference
      graph is incremented per example. Output files and a summary are then
      written after every example. If False, the estimator's stored
      GLOBAL_STEP value is used and a single summary is written at the end.
  """
  tf.logging.info('model_dir=%s', model_dir)
  tf.logging.info('checkpoint_path=%s', checkpoint_path)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('output_dir=%s', output_dir)

  estimator = train_util.create_estimator(
      model_fn, model_dir, hparams, master=master)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES
    # Depends only on hparams, so compute it once instead of per example.
    frames_per_second = data.hparams_frames_per_second(hparams)

    dataset = data.provide_batch(
        examples=examples_path,
        preprocess_examples=preprocess_examples,
        hparams=hparams,
        is_training=False)

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    # merge_all only collects summaries already defined in this graph, so it
    # has to follow define_metrics (presumably where metric summaries live).
    summary_op = tf.summary.merge_all()

    if write_summary_every_step:
      # Step counter local to this graph, advanced once per scored example.
      global_step = tf.train.get_or_create_global_step()
      global_step_increment = global_step.assign_add(1)
    else:
      # Report at the estimator's stored step; "incrementing" just evaluates
      # the constant.
      global_step = tf.constant(
          estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
      global_step_increment = global_step

    def shift_notesequence(ns_time):
      # Shift label times by backward_shift_amount_ms, converted to seconds.
      return ns_time + hparams.backward_shift_amount_ms / 1000.

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      infer_times = []
      num_frames = []

      sess.run(iterator.initializer)
      while True:
        try:
          record = sess.run(next_record)
        except tf.errors.OutOfRangeError:
          break

        # Feeds the already-materialized record back to the Estimator. It is
        # fully consumed (list(...) below) before `record` is rebound.
        def input_fn(params):
          del params
          return tf.data.Dataset.from_tensors(record)

        start_time = time.time()

        # TODO(fjord): This is a hack that allows us to keep using our existing
        # infer/scoring code with a tf.Estimator model. Ideally, we should
        # move things around so that we can use estimator.evaluate, which will
        # also be more efficient because it won't have to restore the checkpoint
        # for every example.
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1
        predictions = prediction_list[0]

        input_features = record[0]
        input_labels = record[1]

        filename = input_features.sequence_id[0]
        note_sequence = music_pb2.NoteSequence.FromString(
            input_labels.note_sequence[0])
        labels = input_labels.labels[0]
        frame_probs = predictions['frame_probs'][0]
        frame_predictions = predictions['frame_predictions'][0]
        onset_predictions = predictions['onset_predictions'][0]
        velocity_values = predictions['velocity_values'][0]
        offset_predictions = predictions['offset_predictions'][0]

        if not FLAGS.require_onset:
          onset_predictions = None

        if not FLAGS.use_offset:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=frames_per_second,
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_predictions.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_predictions.shape[0],
            frame_predictions.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filename)

        sequence_label = sequences_lib.adjust_notesequence_times(
            note_sequence, shift_notesequence)[0]
        infer_util.score_sequence(
            sess,
            global_step_increment,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=frames_per_second,
            sequence_label=sequence_label,
            sequence_id=filename)

        if write_summary_every_step:
          # Make filenames UNIX-friendly.
          filename_safe = filename.decode('utf-8').replace('/', '_').replace(
              ':', '.')
          output_file = os.path.join(output_dir, filename_safe + '.mid')
          tf.logging.info('Writing inferred midi file to %s', output_file)
          midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

          label_output_file = os.path.join(output_dir,
                                           filename_safe + '_label.mid')
          tf.logging.info('Writing label midi file to %s', label_output_file)
          midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

          # Also write a pianoroll showing acoustic model output vs labels.
          pianoroll_output_file = os.path.join(output_dir,
                                               filename_safe + '_pianoroll.png')
          tf.logging.info('Writing acoustic logit/label file to %s',
                          pianoroll_output_file)
          # PNG data is binary, so open the file in binary mode.
          with tf.gfile.GFile(pianoroll_output_file, mode='wb') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    frame_probs,
                    sequence_prediction,
                    labels,
                    overlap=True,
                    frames_per_second=frames_per_second))

          summary = sess.run(summary_op)
          summary_writer.add_summary(summary, sess.run(global_step))
          summary_writer.flush()

      if not write_summary_every_step:
        # Only write the summary variables for the final step.
        summary = sess.run(summary_op)
        summary_writer.add_summary(summary, sess.run(global_step))
        summary_writer.flush()
def model_inference(acoustic_checkpoint, hparams, examples_path, run_dir):
  """Runs inference for the given examples.

  Builds the acoustic model under an 'acoustic' variable scope and restores it
  from `acoustic_checkpoint`. For every batch of `examples_path` it then
  thresholds the frame/onset/offset probabilities into a NoteSequence and
  scores it. The prediction, the label and a pianoroll image are written to
  `run_dir`.

  Reads FLAGS.max_seconds_per_sequence, FLAGS.master, FLAGS.frame_threshold,
  FLAGS.require_onset, FLAGS.onset_threshold, FLAGS.use_offset and
  FLAGS.offset_threshold.

  Args:
    acoustic_checkpoint: Checkpoint holding the acoustic model variables,
      stored without the 'acoustic/' scope prefix.
    hparams: Hyperparameters for the input pipeline and the model.
    examples_path: Examples passed to `data.provide_batch`.
    run_dir: Directory for the summary event file and for the per-example
      `.mid`, `_label.mid` and `_pianoroll.png` outputs.
  """
  tf.logging.info('acoustic_checkpoint=%s', acoustic_checkpoint)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('run_dir=%s', run_dir)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES
    # Depends only on hparams, so compute it once instead of per example.
    frames_per_second = data.hparams_frames_per_second(hparams)

    # Build the acoustic model within an 'acoustic' scope to isolate its
    # variables from the other models.
    with tf.variable_scope('acoustic'):
      truncated_length = 0
      if FLAGS.max_seconds_per_sequence:
        truncated_length = int(
            math.ceil(FLAGS.max_seconds_per_sequence * frames_per_second))
      acoustic_data_provider, _ = data.provide_batch(
          batch_size=1,
          examples=examples_path,
          hparams=hparams,
          is_training=False,
          truncated_length=truncated_length,
          include_note_sequences=True)

      _, _, data_labels, _, _ = model.get_model(
          acoustic_data_provider, hparams, is_training=False)

    # The checkpoints won't have the new scopes, so map every scoped variable
    # back to its unscoped checkpoint name.
    acoustic_variables = {
        re.sub(r'^acoustic/', '', var.op.name): var
        for var in slim.get_variables(scope='acoustic/')
    }
    acoustic_restore = tf.train.Saver(acoustic_variables)

    graph = tf.get_default_graph()
    onset_probs_flat = graph.get_tensor_by_name(
        'acoustic/onsets/onset_probs_flat:0')
    frame_probs_flat = graph.get_tensor_by_name(
        'acoustic/frame_probs_flat:0')
    offset_probs_flat = graph.get_tensor_by_name(
        'acoustic/offsets/offset_probs_flat:0')
    velocity_values_flat = graph.get_tensor_by_name(
        'acoustic/velocity/velocity_values_flat:0')

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()
    global_step = tf.contrib.framework.get_or_create_global_step()
    global_step_increment = global_step.assign_add(1)

    # Use a custom init function to restore the acoustic model from its
    # checkpoint, using the scope-stripped variable mapping above.
    def init_fn(unused_self, sess):
      acoustic_restore.restore(sess, acoustic_checkpoint)

    def shift_notesequence(ns_time):
      # Shift label times by backward_shift_amount_ms, converted to seconds.
      return ns_time + hparams.backward_shift_amount_ms / 1000.

    scaffold = tf.train.Scaffold(init_fn=init_fn)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, master=FLAGS.master)
    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
      tf.logging.info('running session')
      summary_writer = tf.summary.FileWriter(
          logdir=run_dir, graph=sess.graph)
      try:
        tf.logging.info('Inferring for %d batches',
                        acoustic_data_provider.num_batches)
        infer_times = []
        num_frames = []
        for unused_i in range(acoustic_data_provider.num_batches):
          start_time = time.time()
          (labels, filenames, note_sequences, frame_probs, onset_probs,
           offset_probs, velocity_values) = sess.run([
               data_labels,
               acoustic_data_provider.filenames,
               acoustic_data_provider.note_sequences,
               frame_probs_flat,
               onset_probs_flat,
               offset_probs_flat,
               velocity_values_flat,
           ])
          # We expect these all to be length 1 because batch size is 1.
          assert len(filenames) == len(note_sequences) == 1
          # These should be the same length and have been flattened.
          assert len(labels) == len(frame_probs) == len(onset_probs)

          frame_predictions = frame_probs > FLAGS.frame_threshold
          if FLAGS.require_onset:
            onset_predictions = onset_probs > FLAGS.onset_threshold
          else:
            onset_predictions = None

          if FLAGS.use_offset:
            offset_predictions = offset_probs > FLAGS.offset_threshold
          else:
            offset_predictions = None

          sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
              frame_predictions,
              frames_per_second=frames_per_second,
              min_duration_ms=0,
              min_midi_pitch=constants.MIN_MIDI_PITCH,
              onset_predictions=onset_predictions,
              offset_predictions=offset_predictions,
              velocity_values=velocity_values)

          end_time = time.time()
          infer_time = end_time - start_time
          infer_times.append(infer_time)
          num_frames.append(frame_probs.shape[0])
          tf.logging.info(
              'Infer time %f, frames %d, frames/sec %f, running average %f',
              infer_time, frame_probs.shape[0],
              frame_probs.shape[0] / infer_time,
              np.sum(num_frames) / np.sum(infer_times))

          tf.logging.info('Scoring sequence %s', filenames[0])

          sequence_label = infer_util.score_sequence(
              sess,
              global_step_increment,
              summary_op,
              summary_writer,
              metrics_to_updates,
              metric_note_precision,
              metric_note_recall,
              metric_note_f1,
              metric_note_precision_with_offsets,
              metric_note_recall_with_offsets,
              metric_note_f1_with_offsets,
              metric_note_precision_with_offsets_velocity,
              metric_note_recall_with_offsets_velocity,
              metric_note_f1_with_offsets_velocity,
              metric_frame_labels,
              metric_frame_predictions,
              frame_labels=labels,
              sequence_prediction=sequence_prediction,
              frames_per_second=frames_per_second,
              sequence_label=sequences_lib.adjust_notesequence_times(
                  music_pb2.NoteSequence.FromString(note_sequences[0]),
                  shift_notesequence)[0],
              sequence_id=filenames[0])

          # Make filenames UNIX-friendly.
          filename = filenames[0].decode('utf-8').replace('/', '_').replace(
              ':', '.')
          output_file = os.path.join(run_dir, filename + '.mid')
          tf.logging.info('Writing inferred midi file to %s', output_file)
          midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

          label_output_file = os.path.join(run_dir, filename + '_label.mid')
          tf.logging.info('Writing label midi file to %s', label_output_file)
          midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

          # Also write a pianoroll showing acoustic model output vs labels.
          pianoroll_output_file = os.path.join(run_dir,
                                               filename + '_pianoroll.png')
          tf.logging.info('Writing acoustic logit/label file to %s',
                          pianoroll_output_file)
          # PNG data is binary, so open the file in binary mode.
          with tf.gfile.GFile(pianoroll_output_file, mode='wb') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    frame_probs,
                    sequence_prediction,
                    labels,
                    overlap=True,
                    frames_per_second=frames_per_second))

          summary_writer.flush()
      finally:
        # The writer is created in this function, so release it here as well;
        # close() also flushes any pending events.
        summary_writer.close()
def model_inference(model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    write_summary_every_step=True):
    """Runs inference for the given examples.

    Builds an Estimator via `train_util.create_estimator` and iterates over
    every record from `data.provide_batch` (batch size 1). It runs
    `estimator.predict` on each record on its own and thresholds the
    predicted frame/onset/offset probabilities into a NoteSequence. The result
    is then scored with the metrics from `infer_util.define_metrics`.

    Reads FLAGS.max_seconds_per_sequence, FLAGS.frame_threshold,
    FLAGS.require_onset, FLAGS.onset_threshold, FLAGS.use_offset and
    FLAGS.offset_threshold.

    Args:
        model_dir: Model directory passed to `train_util.create_estimator`.
        checkpoint_path: Checkpoint passed to `estimator.predict` for every
            record.
        hparams: Hyperparameters for the estimator and the input pipeline.
            Also used for the frame rate and the label time shift.
        examples_path: Examples passed to `data.provide_batch`.
        output_dir: Directory that receives the per-example `.mid`,
            `_label.mid` and `_pianoroll.png` files. These are only written
            when `write_summary_every_step` is True.
        summary_writer: Writer that receives the merged summaries.
        write_summary_every_step: If True, a global step created in the
            inference graph is incremented per example. Output files and a
            summary are then written after every example. If False, the
            estimator's stored GLOBAL_STEP value is used and a single summary
            is written at the end.
    """
    tf.logging.info('model_dir=%s', model_dir)
    tf.logging.info('checkpoint_path=%s', checkpoint_path)
    tf.logging.info('examples_path=%s', examples_path)
    tf.logging.info('output_dir=%s', output_dir)

    estimator = train_util.create_estimator(model_dir, hparams)

    with tf.Graph().as_default():
        num_dims = constants.MIDI_PITCHES
        # Depends only on hparams, so compute it once instead of per example.
        frames_per_second = data.hparams_frames_per_second(hparams)

        if FLAGS.max_seconds_per_sequence:
            truncated_length = int(
                math.ceil(FLAGS.max_seconds_per_sequence * frames_per_second))
        else:
            truncated_length = 0

        dataset = data.provide_batch(batch_size=1,
                                     examples=examples_path,
                                     hparams=hparams,
                                     is_training=False,
                                     truncated_length=truncated_length)

        # Define some metrics.
        (metrics_to_updates, metric_note_precision, metric_note_recall,
         metric_note_f1, metric_note_precision_with_offsets,
         metric_note_recall_with_offsets, metric_note_f1_with_offsets,
         metric_note_precision_with_offsets_velocity,
         metric_note_recall_with_offsets_velocity,
         metric_note_f1_with_offsets_velocity, metric_frame_labels,
         metric_frame_predictions) = infer_util.define_metrics(num_dims)

        # merge_all only collects summaries already defined in this graph, so
        # it has to follow define_metrics.
        summary_op = tf.summary.merge_all()

        if write_summary_every_step:
            # Step counter local to this graph, advanced once per example.
            global_step = tf.train.get_or_create_global_step()
            global_step_increment = global_step.assign_add(1)
        else:
            # Report at the estimator's stored step; "incrementing" just
            # evaluates the constant.
            global_step = tf.constant(
                estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
            global_step_increment = global_step

        def shift_notesequence(ns_time):
            # Shift label times by backward_shift_amount_ms, in seconds.
            return ns_time + hparams.backward_shift_amount_ms / 1000.

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()
        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            infer_times = []
            num_frames = []

            sess.run(iterator.initializer)
            while True:
                try:
                    record = sess.run(next_record)
                except tf.errors.OutOfRangeError:
                    break

                # Feeds the already-materialized record back to the
                # Estimator. It is fully consumed (list(...) below) before
                # `record` is rebound.
                def input_fn():
                    return tf.data.Dataset.from_tensors(record)

                start_time = time.time()

                # TODO(fjord): This is a hack that allows us to keep using our
                # existing infer/scoring code with a tf.Estimator model.
                # Ideally, we should move things around so that we can use
                # estimator.evaluate, which will also be more efficient
                # because it won't have to restore the checkpoint for every
                # example.
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                assert len(prediction_list) == 1
                predictions = prediction_list[0]

                input_features = record[0]
                input_labels = record[1]

                filename = input_features.sequence_id[0]
                note_sequence = music_pb2.NoteSequence.FromString(
                    input_labels.note_sequence[0])
                labels = input_labels.labels[0]
                frame_probs = predictions['frame_probs_flat']
                onset_probs = predictions['onset_probs_flat']
                velocity_values = predictions['velocity_values_flat']
                offset_probs = predictions['offset_probs_flat']

                frame_predictions = frame_probs > FLAGS.frame_threshold
                if FLAGS.require_onset:
                    onset_predictions = onset_probs > FLAGS.onset_threshold
                else:
                    onset_predictions = None

                if FLAGS.use_offset:
                    offset_predictions = offset_probs > FLAGS.offset_threshold
                else:
                    offset_predictions = None

                sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
                    frame_predictions,
                    frames_per_second=frames_per_second,
                    min_duration_ms=0,
                    min_midi_pitch=constants.MIN_MIDI_PITCH,
                    onset_predictions=onset_predictions,
                    offset_predictions=offset_predictions,
                    velocity_values=velocity_values)

                end_time = time.time()
                infer_time = end_time - start_time
                infer_times.append(infer_time)
                num_frames.append(frame_probs.shape[0])
                tf.logging.info(
                    'Infer time %f, frames %d, frames/sec %f, running average %f',
                    infer_time, frame_probs.shape[0],
                    frame_probs.shape[0] / infer_time,
                    np.sum(num_frames) / np.sum(infer_times))

                tf.logging.info('Scoring sequence %s', filename)

                sequence_label = sequences_lib.adjust_notesequence_times(
                    note_sequence, shift_notesequence)[0]
                infer_util.score_sequence(
                    sess,
                    global_step_increment,
                    metrics_to_updates,
                    metric_note_precision,
                    metric_note_recall,
                    metric_note_f1,
                    metric_note_precision_with_offsets,
                    metric_note_recall_with_offsets,
                    metric_note_f1_with_offsets,
                    metric_note_precision_with_offsets_velocity,
                    metric_note_recall_with_offsets_velocity,
                    metric_note_f1_with_offsets_velocity,
                    metric_frame_labels,
                    metric_frame_predictions,
                    frame_labels=labels,
                    sequence_prediction=sequence_prediction,
                    frames_per_second=frames_per_second,
                    sequence_label=sequence_label,
                    sequence_id=filename)

                if write_summary_every_step:
                    # Make filenames UNIX-friendly.
                    filename_safe = filename.decode('utf-8').replace(
                        '/', '_').replace(':', '.')
                    output_file = os.path.join(output_dir,
                                               filename_safe + '.mid')
                    tf.logging.info('Writing inferred midi file to %s',
                                    output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_prediction, output_file)

                    label_output_file = os.path.join(
                        output_dir, filename_safe + '_label.mid')
                    tf.logging.info('Writing label midi file to %s',
                                    label_output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_label, label_output_file)

                    # Also write a pianoroll showing acoustic model output vs
                    # labels.
                    pianoroll_output_file = os.path.join(
                        output_dir, filename_safe + '_pianoroll.png')
                    tf.logging.info('Writing acoustic logit/label file to %s',
                                    pianoroll_output_file)
                    # PNG data is binary, so open the file in binary mode.
                    with tf.gfile.GFile(pianoroll_output_file,
                                        mode='wb') as f:
                        scipy.misc.imsave(
                            f,
                            infer_util.posterior_pianoroll_image(
                                frame_probs,
                                sequence_prediction,
                                labels,
                                overlap=True,
                                frames_per_second=frames_per_second))

                    summary = sess.run(summary_op)
                    summary_writer.add_summary(summary, sess.run(global_step))
                    summary_writer.flush()

            if not write_summary_every_step:
                # Only write the summary variables for the final step.
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, sess.run(global_step))
                summary_writer.flush()