Example #1
0
def normalize_tempo(sequence, new_tempo=60):
    """Return a copy of `sequence` re-timed to one constant tempo.

    Every interval between consecutive tempo changes is rescaled so that it
    lasts `(end - start) * qpm / new_tempo` seconds, i.e. it keeps the same
    number of quarter notes at `new_tempo` as it had at its original tempo.
    The tempo list of the result is then replaced by a single entry at
    time 0 with `qpm == new_tempo`.

    Args:
        sequence: The sequence to normalize. It must expose `total_time` and
            `tempos` (items with `time` and `qpm` attributes).
        new_tempo: Target tempo in quarter notes per minute.

    Returns:
        A time-adjusted copy of `sequence`. If `total_time` is zero, a plain
        deep copy is returned and its tempo list is left untouched.

    Raises:
        ValueError: If the sequence has a nonzero length but no tempo event
            strictly before `total_time`, so no time mapping can be built.
    """
    if math.isclose(sequence.total_time, 0.):
        return copy_lib.deepcopy(sequence)

    # Only tempo changes that take effect before the end of the sequence
    # contribute to the mapping.
    tempo_events = sorted(
        (tempo.time, tempo.qpm) for tempo in sequence.tempos
        if tempo.time < sequence.total_time)
    if not tempo_events:
        # Without this guard, zip(*[]) below fails with an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError(
            'Cannot normalize tempo: sequence has no tempo events before '
            f'total_time ({sequence.total_time}).')

    tempo_change_times, tempi = zip(*tempo_events)
    original_times = list(tempo_change_times) + [sequence.total_time]
    new_times = [original_times[0]]

    # Iterate through all the intervals between the tempo changes.
    # Compute a new duration for each of them.
    for start, end, qpm in zip(original_times[:-1], original_times[1:],
                               tempi):
        duration = (end - start) * qpm / new_tempo
        new_times.append(new_times[-1] + duration)

    def time_func(t):
        # Piecewise-linear map from original to normalized time. np.interp
        # clamps inputs outside [original_times[0], original_times[-1]].
        return np.interp(t, original_times, new_times)

    adjusted_sequence, skipped_notes = sequences_lib.adjust_notesequence_times(
        sequence, time_func)
    if skipped_notes:
        warnings.warn(
            f'{skipped_notes} notes skipped in adjust_notesequence_times',
            RuntimeWarning)

    # Replace all tempo events with the single normalized tempo.
    del adjusted_sequence.tempos[:]
    tempo = adjusted_sequence.tempos.add()
    tempo.time = 0.
    tempo.qpm = new_tempo

    return adjusted_sequence
Example #2
0
def align_cpp(samples,
              sample_rate,
              ns,
              cqt_hop_length,
              sf2_path,
              penalty_mul=1.0,
              band_radius_seconds=.5):
    """Aligns the notesequence to the wav file using C++ DTW.

    The notesequence is synthesized, CQT features are extracted for both the
    synthesized and the supplied audio, and both are handed to the external
    alignment binary through a temporary protobuf file. The resulting warp
    path is then used to re-time the notesequence.

    Args:
        samples: Samples to align.
        sample_rate: Sample rate for samples.
        ns: The source notesequence to align.
        cqt_hop_length: Hop length to use for CQT calculations.
        sf2_path: Path to SF2 file for synthesis.
        penalty_mul: Penalty multiplier to use for non-diagonal moves.
        band_radius_seconds: What size of band radius to use for restricting
            DTW.

    Returns:
        aligned_ns: The aligned version of the notesequence.
        stats: Dict with the alignment score and mean/median/max/min
            statistics of the warp path (keys such as 'alignment_score',
            'warp_mean_s', 'time_diff_mean_s').

    Raises:
        RuntimeError: If notes are skipped during alignment.
        subprocess.CalledProcessError: If the alignment binary exits with a
            nonzero status.
    """
    logging.info('Synthesizing')
    ns_samples = midi_synth.fluidsynth(
        ns, sf2_path=sf2_path, sample_rate=sample_rate).astype(np.float32)

    # It is critical that ns_samples and samples are the same length because
    # the alignment code does not do subsequence alignment.
    ns_samples = np.pad(ns_samples,
                        (0, max(0, samples.shape[0] - ns_samples.shape[0])),
                        'constant')

    # Pad samples too, if needed, because there are some cases where the
    # synthesized NoteSequence is actually longer.
    samples = np.pad(samples,
                     (0, max(0, ns_samples.shape[0] - samples.shape[0])),
                     'constant')

    # Note that we skip normalization here because it happens in C++.
    logging.info('source_cqt')
    source_cqt = extract_cqt(ns_samples, sample_rate, cqt_hop_length)

    logging.info('dest_cqt')
    dest_cqt = extract_cqt(samples, sample_rate, cqt_hop_length)

    # Each CQT matrix is sent as its two dimensions plus a flattened copy of
    # its values.
    alignment_task = alignment_pb2.AlignmentTask()
    alignment_task.sequence_1.x = source_cqt.shape[0]
    alignment_task.sequence_1.y = source_cqt.shape[1]
    alignment_task.sequence_1.content.extend(source_cqt.reshape([-1]))

    alignment_task.sequence_2.x = dest_cqt.shape[0]
    alignment_task.sequence_2.y = dest_cqt.shape[1]
    alignment_task.sequence_2.content.extend(dest_cqt.reshape([-1]))

    seconds_per_frame = cqt_hop_length / sample_rate

    alignment_task.band_radius = int(band_radius_seconds / seconds_per_frame)
    alignment_task.penalty = 0
    alignment_task.penalty_mul = penalty_mul

    # Exchange data with the C++ program through temporary files. The result
    # is read back from '<input path>.result'.
    fh, temp_path = tempfile.mkstemp(suffix='.proto')
    os.close(fh)
    result_path = temp_path + '.result'
    try:
        # SerializeToString() returns bytes, so the files must be opened in
        # binary mode.
        with open(temp_path, 'wb') as f:
            f.write(alignment_task.SerializeToString())

        # Align with C++ program.
        subprocess.check_call([ALIGN_BINARY, temp_path])

        with open(result_path, 'rb') as f:
            result = alignment_pb2.AlignmentResult.FromString(f.read())
    finally:
        # Clean up even when the binary or the parsing fails. The result file
        # may not exist in that case.
        for path in (temp_path, result_path):
            if os.path.exists(path):
                os.remove(path)

    logging.info('Aligning NoteSequence with warp path.')

    # Convert warp path frame indices to seconds.
    warp_seconds_i = np.array([i * seconds_per_frame for i in result.i])
    warp_seconds_j = np.array([j * seconds_per_frame for j in result.j])

    time_diffs = np.abs(warp_seconds_i - warp_seconds_j)
    warps = np.abs(time_diffs[1:] - time_diffs[:-1])

    stats = {
        'alignment_score': result.score,
        'warp_mean_s': np.mean(warps),
        'warp_median_s': np.median(warps),
        'warp_max_s': np.max(warps),
        'warp_min_s': np.min(warps),
        'time_diff_mean_s': np.mean(time_diffs),
        'time_diff_median_s': np.median(time_diffs),
        'time_diff_max_s': np.max(time_diffs),
        'time_diff_min_s': np.min(time_diffs),
    }

    for name, value in sorted(stats.items()):
        logging.info('%s: %f', name, value)

    # Map every time in the notesequence through the warp path.
    aligned_ns, skipped_notes = sequences_lib.adjust_notesequence_times(
        ns,
        lambda t: np.interp(t, warp_seconds_i, warp_seconds_j),
        minimum_duration=seconds_per_frame)
    if skipped_notes > 0:
        raise RuntimeError('Skipped {} notes'.format(skipped_notes))

    logging.debug('done')

    return aligned_ns, stats
def model_inference(acoustic_checkpoint, hparams, examples_path, run_dir):
  """Runs inference for the given examples.

  Builds the acoustic model in a fresh graph, restores its variables from
  `acoustic_checkpoint`, and then, for every batch of the data provider,
  converts the model outputs to a NoteSequence, scores it against the label
  sequence and writes the prediction MIDI, the label MIDI and a pianoroll
  image into `run_dir`.

  Besides its arguments, this function reads the following flags:
  `max_seconds_per_sequence`, `master`, `frame_threshold`, `require_onset`,
  `onset_threshold`, `use_offset` and `offset_threshold`.

  Args:
    acoustic_checkpoint: Checkpoint path the acoustic variables are restored
      from.
    hparams: Hyperparameters passed to the data provider and the model.
      `backward_shift_amount_ms` is read to shift the label sequence.
    examples_path: Passed as `examples` to `data.provide_batch`.
    run_dir: Directory used as summary log dir and as destination of the
      `.mid`, `_label.mid` and `_pianoroll.png` output files.
  """
  tf.logging.info('acoustic_checkpoint=%s', acoustic_checkpoint)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('run_dir=%s', run_dir)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    # Build the acoustic model within an 'acoustic' scope to isolate its
    # variables from the other models.
    with tf.variable_scope('acoustic'):
      # A truncated_length of 0 is passed when the flag is unset.
      truncated_length = 0
      if FLAGS.max_seconds_per_sequence:
        truncated_length = int(
            math.ceil((FLAGS.max_seconds_per_sequence *
                       data.hparams_frames_per_second(hparams))))
      acoustic_data_provider, _ = data.provide_batch(
          batch_size=1,
          examples=examples_path,
          hparams=hparams,
          is_training=False,
          truncated_length=truncated_length,
          include_note_sequences=True)

      _, _, data_labels, _, _ = model.get_model(
          acoustic_data_provider, hparams, is_training=False)

    # The checkpoints won't have the new scopes.
    # Map each variable back to its name without the 'acoustic/' prefix so the
    # Saver can find it in the checkpoint.
    acoustic_variables = {
        re.sub(r'^acoustic/', '', var.op.name): var
        for var in slim.get_variables(scope='acoustic/')
    }
    acoustic_restore = tf.train.Saver(acoustic_variables)

    # Look up the model output tensors by their graph names.
    onset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/onsets/onset_probs_flat:0')
    frame_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/frame_probs_flat:0')
    offset_probs_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/offsets/offset_probs_flat:0')
    velocity_values_flat = tf.get_default_graph().get_tensor_by_name(
        'acoustic/velocity/velocity_values_flat:0')

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()
    global_step = tf.contrib.framework.get_or_create_global_step()
    global_step_increment = global_step.assign_add(1)

    # Use a custom init function to restore the acoustic model variables from
    # the given checkpoint.
    def init_fn(unused_self, sess):
      acoustic_restore.restore(sess, acoustic_checkpoint)

    scaffold = tf.train.Scaffold(init_fn=init_fn)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, master=FLAGS.master)
    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
      tf.logging.info('running session')
      summary_writer = tf.summary.FileWriter(
          logdir=run_dir, graph=sess.graph)

      tf.logging.info('Inferring for %d batches',
                      acoustic_data_provider.num_batches)
      # Per-example timings and frame counts, used for the running
      # frames/sec average that is logged below.
      infer_times = []
      num_frames = []
      for unused_i in range(acoustic_data_provider.num_batches):
        start_time = time.time()
        (labels, filenames, note_sequences, frame_probs, onset_probs,
         offset_probs, velocity_values) = sess.run([
             data_labels,
             acoustic_data_provider.filenames,
             acoustic_data_provider.note_sequences,
             frame_probs_flat,
             onset_probs_flat,
             offset_probs_flat,
             velocity_values_flat,
         ])
        # We expect these all to be length 1 because batch size is 1.
        assert len(filenames) == len(note_sequences) == 1
        # These should be the same length and have been flattened.
        assert len(labels) == len(frame_probs) == len(onset_probs)

        # Threshold the probabilities. Onset and offset predictions are only
        # used when the corresponding flag is set.
        frame_predictions = frame_probs > FLAGS.frame_threshold
        if FLAGS.require_onset:
          onset_predictions = onset_probs > FLAGS.onset_threshold
        else:
          onset_predictions = None

        if FLAGS.use_offset:
          offset_predictions = offset_probs > FLAGS.offset_threshold
        else:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        # The measured time covers the session run and the conversion to a
        # NoteSequence, but not the scoring below.
        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_probs.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_probs.shape[0], frame_probs.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filenames[0])

        # Time mapping that offsets the label sequence by
        # hparams.backward_shift_amount_ms (converted to seconds).
        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        # adjust_notesequence_times returns a pair; only its first element
        # (the adjusted sequence) is used as label here.
        sequence_label = infer_util.score_sequence(
            sess,
            global_step_increment,
            summary_op,
            summary_writer,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequences_lib.adjust_notesequence_times(
                music_pb2.NoteSequence.FromString(note_sequences[0]),
                shift_notesequence)[0],
            sequence_id=filenames[0])

        # Make filenames UNIX-friendly.
        filename = filenames[0].decode('utf-8').replace('/', '_').replace(
            ':', '.')
        output_file = os.path.join(run_dir, filename + '.mid')
        tf.logging.info('Writing inferred midi file to %s', output_file)
        midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

        label_output_file = os.path.join(run_dir, filename + '_label.mid')
        tf.logging.info('Writing label midi file to %s', label_output_file)
        midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

        # Also write a pianoroll showing acoustic model output vs labels.
        pianoroll_output_file = os.path.join(run_dir,
                                             filename + '_pianoroll.png')
        tf.logging.info('Writing acoustic logit/label file to %s',
                        pianoroll_output_file)
        with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
          scipy.misc.imsave(
              f,
              infer_util.posterior_pianoroll_image(
                  frame_probs,
                  sequence_prediction,
                  labels,
                  overlap=True,
                  frames_per_second=data.hparams_frames_per_second(hparams)))

        summary_writer.flush()
Example #4
0
def _calculate_metrics_py(frame_probs,
                          onset_probs,
                          frame_predictions,
                          onset_predictions,
                          offset_predictions,
                          velocity_values,
                          sequence_label_str,
                          frame_labels,
                          sequence_id,
                          hparams,
                          min_pitch,
                          max_pitch,
                          onsets_only,
                          restrict_to_pitch=None):
    """Compute note-level metrics and frame predictions for one example.

    The model outputs are turned into a predicted sequence, the serialized
    label sequence is parsed (and shifted when
    `hparams.backward_shift_amount_ms` is set), and both are compared with
    mir_eval in four variants: onsets only, onsets + velocity, onsets +
    offsets, and onsets + offsets + velocity.

    Returns:
        A 13-tuple. The first twelve entries are single-element lists holding
        precision, recall and F1 for the four variants in the order above. The
        last entry is a single-element list holding the predicted pianoroll,
        padded or truncated to the number of label frames. If the reference
        has no pitches, the first twelve entries are empty lists and the last
        entry is the pianoroll itself.
    """
    num_label_frames = frame_labels.shape[0]
    tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                    num_label_frames)

    predicted_sequence = infer_util.predict_sequence(
        frame_probs=frame_probs,
        onset_probs=onset_probs,
        frame_predictions=frame_predictions,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values,
        min_pitch=min_pitch,
        hparams=hparams,
        onsets_only=onsets_only)

    label_sequence = music_pb2.NoteSequence.FromString(sequence_label_str)

    if hparams.backward_shift_amount_ms:
        def apply_shift(ns_time):
            return ns_time + hparams.backward_shift_amount_ms / 1000.

        label_sequence, num_skipped = sequences_lib.adjust_notesequence_times(
            label_sequence, apply_shift)
        assert num_skipped == 0

    est_intervals, est_pitches, est_velocities = sequence_to_valued_intervals(
        predicted_sequence, restrict_to_pitch=restrict_to_pitch)
    ref_intervals, ref_pitches, ref_velocities = sequence_to_valued_intervals(
        label_sequence, restrict_to_pitch=restrict_to_pitch)

    predicted_pianoroll = sequences_lib.sequence_to_pianoroll(
        predicted_sequence,
        frames_per_second=data.hparams_frames_per_second(hparams),
        min_pitch=min_pitch,
        max_pitch=max_pitch)
    processed_frame_predictions = predicted_pianoroll.active

    # Bring the predicted pianoroll to the same number of frames as the labels.
    missing_frames = num_label_frames - processed_frame_predictions.shape[0]
    if missing_frames > 0:
        # Pad transcribed frames with silence.
        processed_frame_predictions = np.pad(
            processed_frame_predictions,
            [(0, missing_frames), (0, 0)],
            'constant')
    elif missing_frames < 0:
        # Truncate transcribed frames.
        processed_frame_predictions = (
            processed_frame_predictions[:num_label_frames, :])

    if len(ref_pitches) == 0:
        tf.logging.info(
            'Reference pitches were length 0, returning empty metrics for %s:',
            sequence_id)
        return tuple([[]] * 12) + (processed_frame_predictions,)

    # mir_eval expects pitches in Hz rather than MIDI note numbers.
    ref_hz = pretty_midi.note_number_to_hz(ref_pitches)
    est_hz = pretty_midi.note_number_to_hz(est_pitches)

    # Onsets only.
    note_precision, note_recall, note_f1, _ = (
        mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals, ref_hz, est_intervals, est_hz, offset_ratio=None))

    # Onsets and velocity.
    (note_with_velocity_precision, note_with_velocity_recall,
     note_with_velocity_f1, _) = (
        mir_eval.transcription_velocity.precision_recall_f1_overlap(
            ref_intervals=ref_intervals,
            ref_pitches=ref_hz,
            ref_velocities=ref_velocities,
            est_intervals=est_intervals,
            est_pitches=est_hz,
            est_velocities=est_velocities,
            offset_ratio=None))

    # Onsets and offsets.
    (note_with_offsets_precision, note_with_offsets_recall,
     note_with_offsets_f1, _) = (
        mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals, ref_hz, est_intervals, est_hz))

    # Onsets, offsets and velocity.
    (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
     note_with_offsets_velocity_f1, _) = (
        mir_eval.transcription_velocity.precision_recall_f1_overlap(
            ref_intervals=ref_intervals,
            ref_pitches=ref_hz,
            ref_velocities=ref_velocities,
            est_intervals=est_intervals,
            est_pitches=est_hz,
            est_velocities=est_velocities))

    tf.logging.info(
        'Metrics for %s: Note F1 %f, Note w/ velocity F1 %f, Note w/ offsets F1 %f, '
        'Note w/ offsets & velocity: %f', sequence_id, note_f1,
        note_with_velocity_f1, note_with_offsets_f1,
        note_with_offsets_velocity_f1)

    # Return 1-d tensors for the metrics
    scalar_metrics = (
        note_precision, note_recall, note_f1,
        note_with_velocity_precision, note_with_velocity_recall,
        note_with_velocity_f1,
        note_with_offsets_precision, note_with_offsets_recall,
        note_with_offsets_f1,
        note_with_offsets_velocity_precision,
        note_with_offsets_velocity_recall,
        note_with_offsets_velocity_f1,
    )
    return tuple([value] for value in scalar_metrics) + (
        [processed_frame_predictions],)
Example #5
0
def model_inference(model_fn,
                    model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    master,
                    preprocess_examples,
                    write_summary_every_step=True):
  """Runs inference for the given examples.

  Every record of the dataset is pulled into Python, fed through
  `estimator.predict` on its own, converted to a NoteSequence and scored
  against the (shifted) label sequence.

  Besides its arguments, this function reads the flags `require_onset` and
  `use_offset`.

  Args:
    model_fn: Forwarded to `train_util.create_estimator`.
    model_dir: Forwarded to `train_util.create_estimator`.
    checkpoint_path: Checkpoint passed to `estimator.predict` for every
      example.
    hparams: Hyperparameters passed to the estimator and the data pipeline.
      `backward_shift_amount_ms` is read to shift the label sequence.
    examples_path: Passed as `examples` to `data.provide_batch`.
    output_dir: Directory that receives the `.mid`, `_label.mid` and
      `_pianoroll.png` files (only written when `write_summary_every_step`).
    summary_writer: Writer that receives the evaluated summaries.
    master: Forwarded to `train_util.create_estimator`.
    preprocess_examples: Forwarded to `data.provide_batch`.
    write_summary_every_step: If True, a global step is incremented for every
      example and the output files plus a summary are written per example.
      If False, the global step is read once from the estimator and a single
      summary is written after all examples were processed.
  """
  tf.logging.info('model_dir=%s', model_dir)
  tf.logging.info('checkpoint_path=%s', checkpoint_path)
  tf.logging.info('examples_path=%s', examples_path)
  tf.logging.info('output_dir=%s', output_dir)

  estimator = train_util.create_estimator(
      model_fn, model_dir, hparams, master=master)

  with tf.Graph().as_default():
    num_dims = constants.MIDI_PITCHES

    dataset = data.provide_batch(
        examples=examples_path,
        preprocess_examples=preprocess_examples,
        hparams=hparams,
        is_training=False)

    # Define some metrics.
    (metrics_to_updates, metric_note_precision, metric_note_recall,
     metric_note_f1, metric_note_precision_with_offsets,
     metric_note_recall_with_offsets, metric_note_f1_with_offsets,
     metric_note_precision_with_offsets_velocity,
     metric_note_recall_with_offsets_velocity,
     metric_note_f1_with_offsets_velocity, metric_frame_labels,
     metric_frame_predictions) = infer_util.define_metrics(num_dims)

    summary_op = tf.summary.merge_all()

    if write_summary_every_step:
      # One step per scored example.
      global_step = tf.train.get_or_create_global_step()
      global_step_increment = global_step.assign_add(1)
    else:
      # Use the estimator's global step as a constant, so the op handed to
      # score_sequence does not change the step.
      global_step = tf.constant(
          estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
      global_step_increment = global_step

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()
    with tf.Session() as sess:
      sess.run([
          tf.initializers.global_variables(),
          tf.initializers.local_variables()
      ])

      # Per-example timings and frame counts, used for the running
      # frames/sec average that is logged below.
      infer_times = []
      num_frames = []

      sess.run(iterator.initializer)
      # Consume the dataset one record at a time until it is exhausted.
      while True:
        try:
          record = sess.run(next_record)
        except tf.errors.OutOfRangeError:
          break

        # Wrap the already evaluated record in a dataset so that it can be
        # handed to estimator.predict.
        def input_fn(params):
          del params
          return tf.data.Dataset.from_tensors(record)

        start_time = time.time()

        # TODO(fjord): This is a hack that allows us to keep using our existing
        # infer/scoring code with a tf.Estimator model. Ideally, we should
        # move things around so that we can use estimator.evaluate, which will
        # also be more efficient because it won't have to restore the checkpoint
        # for every example.
        prediction_list = list(
            estimator.predict(
                input_fn,
                checkpoint_path=checkpoint_path,
                yield_single_examples=False))
        assert len(prediction_list) == 1

        input_features = record[0]
        input_labels = record[1]

        # Element 0 of every field is used below.
        # NOTE(review): this presumably relies on a batch size of 1 - confirm
        # against data.provide_batch.
        filename = input_features.sequence_id[0]
        note_sequence = music_pb2.NoteSequence.FromString(
            input_labels.note_sequence[0])
        labels = input_labels.labels[0]
        frame_probs = prediction_list[0]['frame_probs'][0]
        frame_predictions = prediction_list[0]['frame_predictions'][0]
        onset_predictions = prediction_list[0]['onset_predictions'][0]
        velocity_values = prediction_list[0]['velocity_values'][0]
        offset_predictions = prediction_list[0]['offset_predictions'][0]

        # Onset and offset predictions are only used when the corresponding
        # flag is set.
        if not FLAGS.require_onset:
          onset_predictions = None

        if not FLAGS.use_offset:
          offset_predictions = None

        sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
            frame_predictions,
            frames_per_second=data.hparams_frames_per_second(hparams),
            min_duration_ms=0,
            min_midi_pitch=constants.MIN_MIDI_PITCH,
            onset_predictions=onset_predictions,
            offset_predictions=offset_predictions,
            velocity_values=velocity_values)

        # The measured time covers prediction and the conversion to a
        # NoteSequence, but not the scoring below.
        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(frame_predictions.shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, frame_predictions.shape[0],
            frame_predictions.shape[0] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', filename)

        # Time mapping that offsets the label sequence by
        # hparams.backward_shift_amount_ms (converted to seconds).
        def shift_notesequence(ns_time):
          return ns_time + hparams.backward_shift_amount_ms / 1000.

        # adjust_notesequence_times returns a pair; only its first element
        # (the adjusted sequence) is used.
        sequence_label = sequences_lib.adjust_notesequence_times(
            note_sequence, shift_notesequence)[0]
        infer_util.score_sequence(
            sess,
            global_step_increment,
            metrics_to_updates,
            metric_note_precision,
            metric_note_recall,
            metric_note_f1,
            metric_note_precision_with_offsets,
            metric_note_recall_with_offsets,
            metric_note_f1_with_offsets,
            metric_note_precision_with_offsets_velocity,
            metric_note_recall_with_offsets_velocity,
            metric_note_f1_with_offsets_velocity,
            metric_frame_labels,
            metric_frame_predictions,
            frame_labels=labels,
            sequence_prediction=sequence_prediction,
            frames_per_second=data.hparams_frames_per_second(hparams),
            sequence_label=sequence_label,
            sequence_id=filename)

        if write_summary_every_step:
          # Make filenames UNIX-friendly.
          filename_safe = filename.decode('utf-8').replace('/', '_').replace(
              ':', '.')
          output_file = os.path.join(output_dir, filename_safe + '.mid')
          tf.logging.info('Writing inferred midi file to %s', output_file)
          midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

          label_output_file = os.path.join(output_dir,
                                           filename_safe + '_label.mid')
          tf.logging.info('Writing label midi file to %s', label_output_file)
          midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

          # Also write a pianoroll showing acoustic model output vs labels.
          pianoroll_output_file = os.path.join(output_dir,
                                               filename_safe + '_pianoroll.png')
          tf.logging.info('Writing acoustic logit/label file to %s',
                          pianoroll_output_file)
          with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    frame_probs,
                    sequence_prediction,
                    labels,
                    overlap=True,
                    frames_per_second=data.hparams_frames_per_second(hparams)))

          # Write the summary for this example at the current global step.
          summary = sess.run(summary_op)
          summary_writer.add_summary(summary, sess.run(global_step))
          summary_writer.flush()

      if not write_summary_every_step:
        # Only write the summary variables for the final step.
        summary = sess.run(summary_op)
        summary_writer.add_summary(summary, sess.run(global_step))
        summary_writer.flush()
Example #6
0
def _calculate_metrics_py(
    frame_predictions, onset_predictions, offset_predictions, velocity_values,
    sequence_label_str, frame_labels, sequence_id, hparams):
  """Compute note-level metrics and frame predictions for one example.

  The frame predictions are converted to a NoteSequence (using onset/offset
  predictions only when the corresponding `predict_*_threshold` hparam is
  set), the serialized label sequence is parsed and shifted when
  `hparams.backward_shift_amount_ms` is set, and both are compared with
  mir_eval.

  Returns:
    A 10-tuple: precision, recall and F1 for onsets only, for onsets +
    offsets and for onsets + offsets + velocity, followed by the predicted
    pianoroll padded or truncated to the number of label frames.
  """
  num_label_frames = frame_labels.shape[0]
  tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                  num_label_frames)

  # Drop onset/offset predictions whose threshold hparam is unset.
  if not hparams.predict_onset_threshold:
    onset_predictions = None
  if not hparams.predict_offset_threshold:
    offset_predictions = None

  predicted_sequence = sequences_lib.pianoroll_to_note_sequence(
      frames=frame_predictions,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_duration_ms=0,
      min_midi_pitch=constants.MIN_MIDI_PITCH,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values)

  label_sequence = music_pb2.NoteSequence.FromString(sequence_label_str)

  if hparams.backward_shift_amount_ms:

    def apply_shift(ns_time):
      return ns_time + hparams.backward_shift_amount_ms / 1000.

    label_sequence, num_skipped = sequences_lib.adjust_notesequence_times(
        label_sequence, apply_shift)
    assert num_skipped == 0

  est_intervals, est_pitches, est_velocities = (
      infer_util.sequence_to_valued_intervals(predicted_sequence))
  ref_intervals, ref_pitches, ref_velocities = (
      infer_util.sequence_to_valued_intervals(label_sequence))

  # mir_eval expects pitches in Hz rather than MIDI note numbers.
  ref_hz = pretty_midi.note_number_to_hz(ref_pitches)
  est_hz = pretty_midi.note_number_to_hz(est_pitches)

  # Onsets only.
  note_precision, note_recall, note_f1, _ = (
      mir_eval.transcription.precision_recall_f1_overlap(
          ref_intervals, ref_hz, est_intervals, est_hz, offset_ratio=None))

  # Onsets and offsets.
  (note_with_offsets_precision, note_with_offsets_recall,
   note_with_offsets_f1, _) = (
       mir_eval.transcription.precision_recall_f1_overlap(
           ref_intervals, ref_hz, est_intervals, est_hz))

  # Onsets, offsets and velocity.
  (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
   note_with_offsets_velocity_f1, _) = (
       mir_eval.transcription_velocity.precision_recall_f1_overlap(
           ref_intervals=ref_intervals,
           ref_pitches=ref_hz,
           ref_velocities=ref_velocities,
           est_intervals=est_intervals,
           est_pitches=est_hz,
           est_velocities=est_velocities))

  predicted_pianoroll = sequences_lib.sequence_to_pianoroll(
      predicted_sequence,
      frames_per_second=data.hparams_frames_per_second(hparams),
      min_pitch=constants.MIN_MIDI_PITCH,
      max_pitch=constants.MAX_MIDI_PITCH)
  processed_frame_predictions = predicted_pianoroll.active

  # Bring the predicted pianoroll to the same number of frames as the labels.
  missing_frames = num_label_frames - processed_frame_predictions.shape[0]
  if missing_frames > 0:
    # Pad transcribed frames with silence.
    processed_frame_predictions = np.pad(
        processed_frame_predictions, [(0, missing_frames), (0, 0)], 'constant')
  elif missing_frames < 0:
    # Truncate transcribed frames.
    processed_frame_predictions = (
        processed_frame_predictions[:num_label_frames, :])

  tf.logging.info(
      'Metrics for %s: Note F1 %f, Note w/ offsets F1 %f, '
      'Note w/ offsets & velocity: %f', sequence_id, note_f1,
      note_with_offsets_f1, note_with_offsets_velocity_f1)

  return (
      note_precision,
      note_recall,
      note_f1,
      note_with_offsets_precision,
      note_with_offsets_recall,
      note_with_offsets_f1,
      note_with_offsets_velocity_precision,
      note_with_offsets_velocity_recall,
      note_with_offsets_velocity_f1,
      processed_frame_predictions,
  )
def model_inference(model_dir,
                    checkpoint_path,
                    hparams,
                    examples_path,
                    output_dir,
                    summary_writer,
                    write_summary_every_step=True):
    """Runs inference for the given examples.

    Records are read one at a time (batch size 1) from `examples_path`, each
    one is run through a tf.Estimator restored from `checkpoint_path`, the
    thresholded outputs are converted to a NoteSequence, and the result is
    handed to `infer_util.score_sequence` together with the (time-shifted)
    label sequence.

    Besides its arguments, this function reads the module-level FLAGS
    `max_seconds_per_sequence`, `frame_threshold`, `require_onset`,
    `onset_threshold`, `use_offset` and `offset_threshold`.

    Args:
      model_dir: Passed to `train_util.create_estimator` to build the
        estimator.
      checkpoint_path: Checkpoint passed to `estimator.predict` for every
        example.
      hparams: Hyperparameters. Used to build the estimator and the dataset,
        to derive frames per second, and for `backward_shift_amount_ms`.
      examples_path: Passed as `examples` to `data.provide_batch`.
      output_dir: Directory that receives the per-example `.mid`,
        `_label.mid` and `_pianoroll.png` files (only written when
        `write_summary_every_step` is True).
      summary_writer: Object providing `add_summary(summary, step)` and
        `flush()`; receives the merged summaries.
      write_summary_every_step: If True, the global step is incremented per
        example and output files plus a summary are written after every
        example. If False, the global step is fixed to the estimator's
        stored value and a single summary is written after the last example.

    Returns:
      None. Results are emitted through log messages, `summary_writer` and
      (optionally) files in `output_dir`.
    """
    tf.logging.info('model_dir=%s', model_dir)
    tf.logging.info('checkpoint_path=%s', checkpoint_path)
    tf.logging.info('examples_path=%s', examples_path)
    tf.logging.info('output_dir=%s', output_dir)

    estimator = train_util.create_estimator(model_dir, hparams)

    # Everything below (dataset, metrics, summaries, global step) is built in
    # a fresh graph, separate from whatever graph the caller has as default.
    with tf.Graph().as_default():
        num_dims = constants.MIDI_PITCHES

        # Convert the optional length limit from seconds to frames.
        if FLAGS.max_seconds_per_sequence:
            truncated_length = int(
                math.ceil((FLAGS.max_seconds_per_sequence *
                           data.hparams_frames_per_second(hparams))))
        else:
            # NOTE(review): presumably 0 means "do not truncate" in
            # data.provide_batch -- confirm against that function.
            truncated_length = 0

        dataset = data.provide_batch(batch_size=1,
                                     examples=examples_path,
                                     hparams=hparams,
                                     is_training=False,
                                     truncated_length=truncated_length)

        # Define some metrics.
        # The returned handles are forwarded unchanged to
        # infer_util.score_sequence for every example below.
        (metrics_to_updates, metric_note_precision, metric_note_recall,
         metric_note_f1, metric_note_precision_with_offsets,
         metric_note_recall_with_offsets, metric_note_f1_with_offsets,
         metric_note_precision_with_offsets_velocity,
         metric_note_recall_with_offsets_velocity,
         metric_note_f1_with_offsets_velocity, metric_frame_labels,
         metric_frame_predictions) = infer_util.define_metrics(num_dims)

        # Must be created after define_metrics so that any summaries
        # registered there are part of the merged op.
        summary_op = tf.summary.merge_all()

        if write_summary_every_step:
            # One summary per example: use a real, incrementing step variable.
            global_step = tf.train.get_or_create_global_step()
            global_step_increment = global_step.assign_add(1)
        else:
            # Single summary at the end: pin the step to the value stored
            # with the estimator; the "increment" op is then a no-op read.
            global_step = tf.constant(
                estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
            global_step_increment = global_step

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()
        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            # Per-example wall-clock times and frame counts, used only for
            # the running throughput figure in the log.
            infer_times = []
            num_frames = []

            sess.run(iterator.initializer)
            while True:
                # Pull records until the dataset is exhausted.
                try:
                    record = sess.run(next_record)
                except tf.errors.OutOfRangeError:
                    break

                # Wraps the already-evaluated record so the estimator sees a
                # one-element dataset. The closure is consumed immediately by
                # estimator.predict below, within the same iteration.
                def input_fn():
                    return tf.data.Dataset.from_tensors(record)

                # The timed region covers prediction and the conversion of
                # the outputs to a NoteSequence.
                start_time = time.time()

                # TODO(fjord): This is a hack that allows us to keep using our existing
                # infer/scoring code with a tf.Estimator model. Ideally, we should
                # move things around so that we can use estimator.evaluate, which will
                # also be more efficient because it won't have to restore the checkpoint
                # for every example.
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                # One input record with yield_single_examples=False is
                # expected to yield exactly one batch of predictions.
                assert len(prediction_list) == 1

                # record is a (features, labels) pair.
                input_features = record[0]
                input_labels = record[1]

                # Index 0 selects the only element of the size-1 batch.
                filename = input_features.sequence_id[0]
                note_sequence = music_pb2.NoteSequence.FromString(
                    input_labels.note_sequence[0])
                labels = input_labels.labels[0]
                frame_probs = prediction_list[0]['frame_probs_flat']
                onset_probs = prediction_list[0]['onset_probs_flat']
                velocity_values = prediction_list[0]['velocity_values_flat']
                offset_probs = prediction_list[0]['offset_probs_flat']

                # Threshold probabilities into boolean predictions. Onset and
                # offset predictions are only supplied when enabled by flags.
                frame_predictions = frame_probs > FLAGS.frame_threshold
                if FLAGS.require_onset:
                    onset_predictions = onset_probs > FLAGS.onset_threshold
                else:
                    onset_predictions = None

                if FLAGS.use_offset:
                    offset_predictions = offset_probs > FLAGS.offset_threshold
                else:
                    offset_predictions = None

                sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
                    frame_predictions,
                    frames_per_second=data.hparams_frames_per_second(hparams),
                    min_duration_ms=0,
                    min_midi_pitch=constants.MIN_MIDI_PITCH,
                    onset_predictions=onset_predictions,
                    offset_predictions=offset_predictions,
                    velocity_values=velocity_values)

                end_time = time.time()
                infer_time = end_time - start_time
                infer_times.append(infer_time)
                num_frames.append(frame_probs.shape[0])
                # The running average is total frames over total seconds for
                # all examples processed so far.
                tf.logging.info(
                    'Infer time %f, frames %d, frames/sec %f, running average %f',
                    infer_time, frame_probs.shape[0],
                    frame_probs.shape[0] / infer_time,
                    np.sum(num_frames) / np.sum(infer_times))

                tf.logging.info('Scoring sequence %s', filename)

                # Adds hparams.backward_shift_amount_ms (converted from
                # milliseconds to seconds) to every label time.
                def shift_notesequence(ns_time):
                    return ns_time + hparams.backward_shift_amount_ms / 1000.

                # adjust_notesequence_times returns (sequence, skipped_notes);
                # only the sequence is kept here, so skipped notes go
                # unreported.
                # NOTE(review): _calculate_metrics_py asserts that no notes
                # were skipped and only shifts when the amount is non-zero --
                # confirm whether this path should behave the same way.
                sequence_label = sequences_lib.adjust_notesequence_times(
                    note_sequence, shift_notesequence)[0]
                # NOTE(review): presumably score_sequence runs
                # global_step_increment and the metric updates in `sess` for
                # this example -- confirm in infer_util.
                infer_util.score_sequence(
                    sess,
                    global_step_increment,
                    metrics_to_updates,
                    metric_note_precision,
                    metric_note_recall,
                    metric_note_f1,
                    metric_note_precision_with_offsets,
                    metric_note_recall_with_offsets,
                    metric_note_f1_with_offsets,
                    metric_note_precision_with_offsets_velocity,
                    metric_note_recall_with_offsets_velocity,
                    metric_note_f1_with_offsets_velocity,
                    metric_frame_labels,
                    metric_frame_predictions,
                    frame_labels=labels,
                    sequence_prediction=sequence_prediction,
                    frames_per_second=data.hparams_frames_per_second(hparams),
                    sequence_label=sequence_label,
                    sequence_id=filename)

                if write_summary_every_step:
                    # Make filenames UNIX-friendly.
                    # filename is bytes here, hence the decode; '/' and ':'
                    # are replaced so the id can be used as a single path
                    # component.
                    filename_safe = filename.decode('utf-8').replace(
                        '/', '_').replace(':', '.')
                    output_file = os.path.join(output_dir,
                                               filename_safe + '.mid')
                    tf.logging.info('Writing inferred midi file to %s',
                                    output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_prediction, output_file)

                    label_output_file = os.path.join(
                        output_dir, filename_safe + '_label.mid')
                    tf.logging.info('Writing label midi file to %s',
                                    label_output_file)
                    midi_io.sequence_proto_to_midi_file(
                        sequence_label, label_output_file)

                    # Also write a pianoroll showing acoustic model output vs labels.
                    pianoroll_output_file = os.path.join(
                        output_dir, filename_safe + '_pianoroll.png')
                    tf.logging.info('Writing acoustic logit/label file to %s',
                                    pianoroll_output_file)
                    # NOTE(review): scipy.misc.imsave is deprecated and absent
                    # from newer SciPy releases -- confirm the SciPy version
                    # this file is expected to run against.
                    with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
                        scipy.misc.imsave(
                            f,
                            infer_util.posterior_pianoroll_image(
                                frame_probs,
                                sequence_prediction,
                                labels,
                                overlap=True,
                                frames_per_second=data.
                                hparams_frames_per_second(hparams)))

                    # Emit a summary for this example at the current step.
                    summary = sess.run(summary_op)
                    summary_writer.add_summary(summary, sess.run(global_step))
                    summary_writer.flush()

            if not write_summary_every_step:
                # Only write the summary variables for the final step.
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, sess.run(global_step))
                summary_writer.flush()
Example #8
0
def _calculate_metrics_py(frame_predictions, onset_predictions,
                          offset_predictions, velocity_values,
                          sequence_label_str, frame_labels, sequence_id,
                          hparams):
    """Python logic for calculating metrics on a single example.

    Converts the predicted pianoroll to a NoteSequence, compares it with the
    (optionally time-shifted) label sequence using mir_eval, and re-renders
    the predicted sequence as a pianoroll whose length matches the labels.

    Args:
      frame_predictions: Predicted frames, passed as `frames` to
        `sequences_lib.pianoroll_to_note_sequence`.
      onset_predictions: Predicted onsets. Ignored (treated as None) when
        `hparams.predict_onset_threshold` is falsy.
      offset_predictions: Predicted offsets. Ignored (treated as None) when
        `hparams.predict_offset_threshold` is falsy.
      velocity_values: Predicted velocities, forwarded to
        `sequences_lib.pianoroll_to_note_sequence`.
      sequence_label_str: Serialized `music_pb2.NoteSequence` holding the
        reference notes.
      frame_labels: Label frames. Only `shape[0]` is used, as the target
        number of frames for the returned pianoroll.
      sequence_id: Identifier of the example; used in log messages only.
      hparams: Hyperparameters. Reads `predict_onset_threshold`,
        `predict_offset_threshold` and `backward_shift_amount_ms`, and is
        passed to `data.hparams_frames_per_second`.

    Returns:
      A 10-tuple:
      (note_precision, note_recall, note_f1,
       note_with_offsets_precision, note_with_offsets_recall,
       note_with_offsets_f1,
       note_with_offsets_velocity_precision,
       note_with_offsets_velocity_recall, note_with_offsets_velocity_f1,
       processed_frame_predictions)
      where the first triple is computed with `offset_ratio=None` and
      `processed_frame_predictions` is the predicted sequence rendered back
      to a pianoroll, padded or truncated to `frame_labels.shape[0]` frames.

    Raises:
      AssertionError: If shifting the label sequence skipped any notes.
    """
    tf.logging.info('Calculating metrics for %s with length %d', sequence_id,
                    frame_labels.shape[0])
    if not hparams.predict_onset_threshold:
        onset_predictions = None
    if not hparams.predict_offset_threshold:
        offset_predictions = None

    # Used for both directions of the pianoroll <-> NoteSequence conversion.
    frames_per_second = data.hparams_frames_per_second(hparams)

    sequence_prediction = sequences_lib.pianoroll_to_note_sequence(
        frames=frame_predictions,
        frames_per_second=frames_per_second,
        min_duration_ms=0,
        min_midi_pitch=constants.MIN_MIDI_PITCH,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values)

    sequence_label = music_pb2.NoteSequence.FromString(sequence_label_str)

    if hparams.backward_shift_amount_ms:

        def shift_notesequence(ns_time):
            # Shift amount is configured in milliseconds; times are seconds.
            return ns_time + hparams.backward_shift_amount_ms / 1000.

        shifted_sequence_label, skipped_notes = (
            sequences_lib.adjust_notesequence_times(sequence_label,
                                                    shift_notesequence))
        # Raised explicitly (rather than via `assert`) so the check still
        # runs under `python -O`; the exception type is unchanged.
        if skipped_notes != 0:
            raise AssertionError(
                '%s notes skipped while shifting the label sequence for %s' %
                (skipped_notes, sequence_id))
        sequence_label = shifted_sequence_label

    est_intervals, est_pitches, est_velocities = (
        infer_util.sequence_to_valued_intervals(sequence_prediction))

    ref_intervals, ref_pitches, ref_velocities = (
        infer_util.sequence_to_valued_intervals(sequence_label))

    # mir_eval expects pitches in Hz; convert once and reuse for all three
    # metric families instead of recomputing per call.
    ref_pitches_hz = pretty_midi.note_number_to_hz(ref_pitches)
    est_pitches_hz = pretty_midi.note_number_to_hz(est_pitches)

    # Note metrics with offsets ignored (offset_ratio=None).
    note_precision, note_recall, note_f1, _ = (
        mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals,
            ref_pitches_hz,
            est_intervals,
            est_pitches_hz,
            offset_ratio=None))

    # Note metrics using mir_eval's default offset matching.
    (note_with_offsets_precision, note_with_offsets_recall,
     note_with_offsets_f1, _) = (
         mir_eval.transcription.precision_recall_f1_overlap(
             ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz))

    # Note metrics that additionally take velocities into account.
    (note_with_offsets_velocity_precision, note_with_offsets_velocity_recall,
     note_with_offsets_velocity_f1, _) = (
         mir_eval.transcription_velocity.precision_recall_f1_overlap(
             ref_intervals=ref_intervals,
             ref_pitches=ref_pitches_hz,
             ref_velocities=ref_velocities,
             est_intervals=est_intervals,
             est_pitches=est_pitches_hz,
             est_velocities=est_velocities))

    # Render the predicted notes back to a pianoroll so that the frame
    # predictions reflect the same post-processing as the note metrics.
    processed_frame_predictions = sequences_lib.sequence_to_pianoroll(
        sequence_prediction,
        frames_per_second=frames_per_second,
        min_pitch=constants.MIN_MIDI_PITCH,
        max_pitch=constants.MAX_MIDI_PITCH).active

    target_frames = frame_labels.shape[0]
    predicted_frames = processed_frame_predictions.shape[0]
    if predicted_frames < target_frames:
        # Pad transcribed frames with silence.
        pad_length = target_frames - predicted_frames
        processed_frame_predictions = np.pad(processed_frame_predictions,
                                             [(0, pad_length), (0, 0)],
                                             'constant')
    elif predicted_frames > target_frames:
        # Truncate transcribed frames.
        processed_frame_predictions = (
            processed_frame_predictions[:target_frames, :])

    tf.logging.info(
        'Metrics for %s: Note F1 %f, Note w/ offsets F1 %f, '
        'Note w/ offsets & velocity: %f', sequence_id, note_f1,
        note_with_offsets_f1, note_with_offsets_velocity_f1)
    return (note_precision, note_recall, note_f1, note_with_offsets_precision,
            note_with_offsets_recall, note_with_offsets_f1,
            note_with_offsets_velocity_precision,
            note_with_offsets_velocity_recall, note_with_offsets_velocity_f1,
            processed_frame_predictions)