def pianoroll_summary(batch, step, name, frame_rate, pianoroll_key):
  """Plots ground truth pianoroll against predicted MIDI.

  For every example in the batch, the ground-truth pianoroll is drawn in the
  green channel and the prediction in the blue channel of one RGBA image, so
  cells where both are active appear cyan. Each image is written with
  `fig_summary` under the tag `pianoroll/{name}_{i + 1}`.

  Args:
    batch: Dict holding ground-truth pianorolls under
      'note_active_velocities' and predictions under `pianoroll_key`.
      Pianorolls are indexed as [frame, pitch] (assumed from the
      `argmax(..., axis=1)` note lookup and the 0-127 y-range below —
      NOTE(review): confirm against the producer of the batch).
    step: Step value forwarded to `fig_summary`.
    name: Name inserted into the summary tag.
    frame_rate: Frames per second used to render a prediction that is a
      `note_seq.NoteSequence` into a pianoroll.
    pianoroll_key: Key in `batch` of the predicted pianorolls or
      NoteSequences.
  """
  batch_size = batch['note_active_velocities'].shape[0]
  for i in range(batch_size):
    gt_pianoroll = batch['note_active_velocities'][i]
    pred_pianoroll = batch[pianoroll_key][i]
    if isinstance(pred_pianoroll, note_seq.NoteSequence):
      # The final rendered frame is dropped (presumably so the frame count
      # lines up with the ground truth — TODO confirm).
      pred_pianoroll = sequences_lib.sequence_to_pianoroll(
          pred_pianoroll,
          frames_per_second=frame_rate,
          min_pitch=note_seq.MIN_MIDI_PITCH,
          max_pitch=note_seq.MAX_MIDI_PITCH).active[:-1, :]

    # RGBA image laid out as (pitch, frame, channel) so pitch is the y-axis.
    img = np.zeros((gt_pianoroll.shape[1], gt_pianoroll.shape[0], 4))

    # All values in `rgb` should be 0.0 except the value at index `idx`
    gt_color = {'idx': 1, 'rgb': np.array([0.0, 1.0, 0.0])}  # green
    pred_color = {'idx': 2, 'rgb': np.array([0.0, 0.0, 1.0])}  # blue

    gt_pianoroll_t = np.transpose(gt_pianoroll)
    pred_pianoroll_t = np.transpose(pred_pianoroll)
    img[:, :, gt_color['idx']] = gt_pianoroll_t
    img[:, :, pred_color['idx']] = pred_pianoroll_t

    # this is the alpha channel: opaque wherever either roll is active.
    img[:, :, 3] = np.logical_or(gt_pianoroll_t > 0.0, pred_pianoroll_t > 0.0)

    # Determine the min & max y-values for plotting. `argmax` returns 0 for
    # frames with no active pitch, so index 0 is treated as "no note".
    gt_note_indices = np.argmax(gt_pianoroll, axis=1)
    pred_note_indices = np.argmax(pred_pianoroll, axis=1)
    all_note_indices = np.concatenate([gt_note_indices, pred_note_indices])
    # Test the note values directly; summing the *positions* from
    # `np.nonzero` would report "no notes" when the only note is at
    # position 0.
    active_note_indices = all_note_indices[all_note_indices > 0]
    if active_note_indices.size > 0:
      lower_limit = np.min(active_note_indices)
      upper_limit = np.max(active_note_indices)
    else:
      lower_limit = 0
      upper_limit = 127

    # Make the figures and add them to the summary.
    fig, ax, _ = pianoroll_plot_setup(figsize=(6.0, 4.0))
    ax.imshow(img, origin='lower', aspect='auto')
    # Pad the visible pitch range by 5 on each side, clamped to [0, 127].
    ax.set_ylim((max(lower_limit - 5, 0), min(upper_limit + 5, 127)))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    labels_and_colors = [
        ('GT MIDI', gt_color['rgb']),  # green
        ('Pred MIDI', pred_color['rgb']),  # blue
        ('Overlap', gt_color['rgb'] + pred_color['rgb'])  # cyan
    ]
    patches = [mpatches.Patch(label=l, color=c) for l, c in labels_and_colors]
    fig.legend(handles=patches)
    fig_summary(f'pianoroll/{name}_{i + 1}', fig, step)
def _midiae_f0_helper(q_pitch, f0_midi, curve, i, step, label, tag):
  """Helper function to plot F0 info with MIDI AE.

  Closes all open matplotlib figures, then plots `q_pitch` as a red step
  curve, `f0_midi` in blue and `curve` in dark green on one set of axes. The
  y-limits come from `_get_reasonable_f0_min_max(f0_midi)`. The figure is
  written with `fig_summary` under the tag `{tag}/ex_{i + 1}`.

  Args:
    q_pitch: Pitch track drawn as a step plot (legend 'q_pitch').
    f0_midi: Input f0 track (legend 'input f0'); also determines y-limits.
    curve: Additional track to compare, drawn with legend `label`.
    i: Zero-based example index; the tag uses `i + 1`.
    step: Step value forwarded to `fig_summary`.
    label: Legend label for `curve`.
    tag: Prefix of the summary tag.
  """
  min_, max_ = _get_reasonable_f0_min_max(f0_midi)
  plt.close('all')

  fig, ax, sp = pianoroll_plot_setup(figsize=(6.0, 4.0))
  sp.set_ylabel('MIDI Note Value')

  ax.step(q_pitch, 'r', linewidth=1.0, label='q_pitch')
  ax.plot(f0_midi, 'dodgerblue', linewidth=1.5, label='input f0')
  ax.plot(curve, 'darkgreen', linewidth=1.25, label=label)

  ax.set_ylim(min_, max_)
  ax.yaxis.set_major_locator(MaxNLocator(integer=True))
  ax.legend()
  # `fig_summary` takes (tag, fig, step) positionally, matching how
  # `pianoroll_summary` calls it; passing `fig` first while also giving
  # `tag=` would bind two values to `tag` and raise TypeError.
  fig_summary(f'{tag}/ex_{i + 1}', fig, step)