Example #1
0
def pianoroll_summary(batch, step, name, frame_rate, pianoroll_key):
  """Plots ground truth pianoroll against predicted MIDI.

  One figure is produced per batch example. The ground truth pianoroll is
  written to the green channel and the prediction to the blue channel of an
  RGBA image, so overlapping notes show up as cyan. Each figure is handed to
  `fig_summary` under the tag `pianoroll/{name}_{i + 1}`.

  Args:
    batch: Mapping holding the ground truth under 'note_active_velocities' and
      the prediction under `pianoroll_key`. Each ground truth example is
      treated as a 2-D [frames, pitches] array. Each prediction example is
      either an array with the same shape as the ground truth example, or a
      `note_seq.NoteSequence`, which is rasterized to a pianoroll first.
    step: Step value forwarded to `fig_summary`.
    name: Name embedded in the summary tag.
    frame_rate: Frames per second used when rasterizing a `NoteSequence`.
    pianoroll_key: Key of the predicted pianoroll (or NoteSequence) in `batch`.
  """
  gt_pianorolls = batch['note_active_velocities']
  batch_size = gt_pianorolls.shape[0]

  # All values in `rgb` should be 0.0 except the value at index `idx`
  gt_color = {'idx': 1, 'rgb': np.array([0.0, 1.0, 0.0])}  # green
  pred_color = {'idx': 2, 'rgb': np.array([0.0, 0.0, 1.0])}  # blue

  for i in range(batch_size):
    gt_pianoroll = gt_pianorolls[i]
    pred_pianoroll = batch[pianoroll_key][i]

    if isinstance(pred_pianoroll, note_seq.NoteSequence):
      # The final rasterized frame is dropped; presumably this aligns the
      # frame count with the ground truth -- TODO confirm against the caller.
      pred_pianoroll = sequences_lib.sequence_to_pianoroll(
          pred_pianoroll,
          frames_per_second=frame_rate,
          min_pitch=note_seq.MIN_MIDI_PITCH,
          max_pitch=note_seq.MAX_MIDI_PITCH).active[:-1, :]

    # Image is laid out [pitch, frame, RGBA] so that pitch runs along the
    # y-axis when shown with `imshow`.
    img = np.zeros((gt_pianoroll.shape[1], gt_pianoroll.shape[0], 4))

    gt_pianoroll_t = np.transpose(gt_pianoroll)
    pred_pianoroll_t = np.transpose(pred_pianoroll)
    img[:, :, gt_color['idx']] = gt_pianoroll_t
    img[:, :, pred_color['idx']] = pred_pianoroll_t

    # this is the alpha channel: opaque wherever either pianoroll is active.
    img[:, :, 3] = np.logical_or(gt_pianoroll_t > 0.0, pred_pianoroll_t > 0.0)

    # Determine the min & max y-values for plotting. `argmax` yields 0 for a
    # silent frame, so index 0 is treated as "no note" and excluded.
    gt_note_indices = np.argmax(gt_pianoroll, axis=1)
    pred_note_indices = np.argmax(pred_pianoroll, axis=1)
    all_note_indices = np.concatenate([gt_note_indices, pred_note_indices])
    active_note_indices = all_note_indices[all_note_indices > 0]

    # Check for the *existence* of an active frame. Summing the positions
    # returned by `np.nonzero` would give 0 when the only active frame is at
    # position 0, wrongly selecting the full-range fallback.
    if active_note_indices.size > 0:
      lower_limit = np.min(active_note_indices)
      upper_limit = np.max(active_note_indices)
    else:
      lower_limit = 0
      upper_limit = 127

    # Make the figures and add them to the summary.
    fig, ax, _ = pianoroll_plot_setup(figsize=(6.0, 4.0))
    ax.imshow(img, origin='lower', aspect='auto')
    # Pad the visible pitch range by 5 on each side, clamped to [0, 127].
    ax.set_ylim((max(lower_limit - 5, 0), min(upper_limit + 5, 127)))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    labels_and_colors = [
        ('GT MIDI', gt_color['rgb']),  # green
        ('Pred MIDI', pred_color['rgb']),  # blue
        ('Overlap', gt_color['rgb'] + pred_color['rgb'])  # cyan
    ]
    patches = [mpatches.Patch(label=l, color=c) for l, c in labels_and_colors]
    fig.legend(handles=patches)
    fig_summary(f'pianoroll/{name}_{i + 1}', fig, step)
Example #2
0
def _midiae_f0_helper(q_pitch, f0_midi, curve, i, step, label, tag):
    """Plot quantized pitch, input F0 and one extra curve for MIDI AE.

    Args:
        q_pitch: Series drawn as a red step plot labelled 'q_pitch'.
        f0_midi: Series drawn as the 'input f0' line; also passed to
            `_get_reasonable_f0_min_max` to choose the y-axis range.
        curve: Additional series drawn in dark green.
        i: Example index; the summary tag uses `i + 1`.
        step: Step value forwarded to `fig_summary`.
        label: Legend label for `curve`.
        tag: Prefix of the summary tag.
    """
    y_low, y_high = _get_reasonable_f0_min_max(f0_midi)

    # Close any figures that are still open before creating a new one.
    plt.close('all')
    figure, axes, subplot = pianoroll_plot_setup(figsize=(6.0, 4.0))
    subplot.set_ylabel('MIDI Note Value')

    # Quantized pitch is a step plot; the two continuous curves share one loop.
    axes.step(q_pitch, 'r', linewidth=1.0, label='q_pitch')
    line_specs = (
        (f0_midi, 'dodgerblue', 1.5, 'input f0'),
        (curve, 'darkgreen', 1.25, label),
    )
    for series, color, width, legend_name in line_specs:
        axes.plot(series, color, linewidth=width, label=legend_name)

    axes.set_ylim(y_low, y_high)
    axes.yaxis.set_major_locator(MaxNLocator(integer=True))
    axes.legend()

    summary_tag = f'{tag}/ex_{i + 1}'
    fig_summary(figure, step=step, tag=summary_tag)