示例#1
0
def TrimPaddingAndPlotAttention(fig,
                                axes,
                                atten_matrix,
                                src_len,
                                tgt_len,
                                transcript=None,
                                **kwargs):
    """Trims axes of atten_matrix with shape (tgt_time, src_time) and plots it.

    For use as a plot function with MatplotlibFigureSummary.

    The matrix is cut to ``atten_matrix[:tgt_len, :src_len]`` before drawing.
    The color limits default to (0, 1); a caller may override them by passing
    ``clim`` in ``kwargs``.

    Args:
      fig:  A matplotlib figure handle.
      axes:  A matplotlib axes handle.
      atten_matrix:  A 2D ndarray shaped (tgt_time, src_time).
      src_len:  Integer length to use to trim the src_time axis of atten_matrix.
      tgt_len:  Integer length to use to trim the tgt_time axis of atten_matrix.
      transcript: transcript for the target sequence. Either a string, or an
        ndarray of tokens (str or bytes) whose first src_len entries are joined
        with single spaces. Used as the x-axis label.
      **kwargs:  Additional keyword args to pass to plot.AddImage.
    """
    # clim used to be passed unconditionally alongside **kwargs, so a caller
    # supplying its own clim hit "got multiple values for keyword argument".
    # Keep (0, 1) as the default but allow it to be overridden.
    kwargs.setdefault('clim', (0, 1))
    plot.AddImage(fig, axes, atten_matrix[:tgt_len, :src_len], **kwargs)
    if transcript is not None:
        if isinstance(transcript, np.ndarray):
            # Tokens may be bytes; str.join raises TypeError on bytes items,
            # so convert each token to text before joining.
            # NOTE(review): the docstring calls this the target transcript, yet
            # it is trimmed with src_len and shown on the x axis -- confirm
            # which sequence callers actually pass.
            transcript = ' '.join(
                plot.ToUnicode(token) for token in transcript[:src_len])
        axes.set_xlabel(plot.ToUnicode(transcript), size='x-small', wrap=True)
示例#2
0
 def Draw(fig, axes, img, label, pred):
     """Plots channel 0 of img and captions it with the label and prediction.

     Args:
       fig: A matplotlib figure handle.
       axes: A matplotlib axes handle.
       img: Image array indexed as img[:, :, 0]; only that slice is drawn,
         after dividing by 256. Presumably 0-255 pixel values -- TODO confirm.
       label: Reference label, rendered with %d.
       pred: Predicted label, rendered with %d.
     """
     first_channel = img[:, :, 0] / 256.
     plot.AddImage(fig=fig, axes=axes, data=first_channel,
                   show_colorbar=False, suppress_xticks=True,
                   suppress_yticks=True)
     # Caption sits at the bottom centre, in axes-relative coordinates.
     caption = u'%d vs. %d' % (label, pred)
     axes.text(x=0.5, y=0, s=caption, transform=axes.transAxes,
               horizontalalignment='center')
示例#3
0
def TrimPaddingAndPlotSequence(fig, axes, seq_matrix, seq_len, **kwargs):
    """Plots seq_matrix with its time axis cut to the first seq_len steps.

    For use as a plot function with MatplotlibFigureSummary.

    Args:
      fig:  A matplotlib figure handle.
      axes:  A matplotlib axes handle.
      seq_matrix:  A 2D ndarray shaped (num_rows, time).
      seq_len:  Integer length to use to trim the time axis of seq_matrix.
      **kwargs:  Additional keyword args to pass to plot.AddImage.
    """
    # All rows are kept; only the trailing (presumably padded) time columns
    # beyond seq_len are dropped.
    unpadded = seq_matrix[:, :seq_len]
    plot.AddImage(fig, axes, unpadded, **kwargs)
示例#4
0
 def DrawCameraImage(fig, axes, frontal_image, run_segment_id):
     """Draws a camera image for an image summary, captioned with its id.

     Args:
       fig: A matplotlib figure handle.
       axes: A matplotlib axes handle.
       frontal_image: Image array; divided by 256 before plotting.
       run_segment_id: Text drawn at the bottom centre of the axes.
     """
     scaled_image = frontal_image / 256.
     plot.AddImage(fig=fig, axes=axes, data=scaled_image,
                   show_colorbar=False, suppress_xticks=True,
                   suppress_yticks=True)
     caption = axes.text(x=0.5, y=0.01, s=run_segment_id, color='blue',
                         fontsize=14, transform=axes.transAxes,
                         horizontalalignment='center')
     # Draw a light-blue stroke first, then the normal text on top of it,
     # so the caption stays legible over the image.
     outline = path_effects.Stroke(linewidth=3, foreground='lightblue')
     caption.set_path_effects([outline, path_effects.Normal()])
示例#5
0
 def TrimAndAddImage(fig, axes, data, trim, title, **kwargs):
     """Plots data cropped to its first trim[0] rows and trim[1] columns.

     Args:
       fig: A matplotlib figure handle.
       axes: A matplotlib axes handle.
       data: A 2D-indexable array to plot.
       trim: Sequence whose first two entries are the row and column limits.
       title: Title forwarded to plot.AddImage.
       **kwargs: Additional keyword args forwarded to plot.AddImage.
     """
     max_rows, max_cols = trim[0], trim[1]
     cropped = data[:max_rows, :max_cols]
     plot.AddImage(fig, axes, cropped, title=title, **kwargs)
示例#6
0
 def PlotAttention(fig, axes, cur_atten_probs, title, set_x_label):
   """Plots an attention-probability matrix and labels its axes.

   Args:
     fig: A matplotlib figure handle.
     axes: A matplotlib axes handle.
     cur_atten_probs: Matrix forwarded to plot.AddImage.
     title: Title forwarded to plot.AddImage.
     set_x_label: When truthy, the x axis is labelled as well.
   """
   plot.AddImage(fig, axes, cur_atten_probs, title=title)
   y_label = plot.ToUnicode('Output sequence index')
   axes.set_ylabel(y_label, wrap=True)
   if not set_x_label:
     return
   axes.set_xlabel(plot.ToUnicode('Input sequence index'), wrap=True)