Пример #1
0
def TrimPaddingAndPlotAttention(fig,
                                axes,
                                atten_matrix,
                                src_len,
                                tgt_len,
                                transcript=None,
                                **kwargs):
    """Trims axes of atten_matrix with shape (tgt_time, src_time) and plots it.

    For use as a plot function with MatplotlibFigureSummary.

    The matrix is cropped to its first `tgt_len` rows and first `src_len`
    columns and drawn with `plot.AddImage` using a fixed color range of
    (0, 1).

    Args:
      fig: A matplotlib figure handle.
      axes: A matplotlib axes handle.
      atten_matrix: A 2D ndarray shaped (tgt_time, src_time).
      src_len: Integer length used to trim the src_time axis of atten_matrix.
      tgt_len: Integer length used to trim the tgt_time axis of atten_matrix.
      transcript: Optional text used as the x-axis label. Either a string
        (str or bytes) or an ndarray of tokens. An ndarray is trimmed to its
        first `src_len` entries and its tokens are joined with single spaces.
        NOTE(review): this was previously documented as the transcript of
        the target sequence, yet it is trimmed with `src_len` and used as the
        x-axis label; confirm which sequence callers actually pass.
      **kwargs: Additional keyword args to pass to plot.AddImage.
    """
    plot.AddImage(fig,
                  axes,
                  atten_matrix[:tgt_len, :src_len],
                  clim=(0, 1),
                  **kwargs)
    if transcript is None:
        return
    if isinstance(transcript, np.ndarray):
        # Token arrays may hold bytes elements (e.g. numpy.bytes_), and
        # str.join() only accepts str, so convert every token to text first.
        # For str tokens plot.ToUnicode returns an equal string (no-op).
        transcript = ' '.join(
            plot.ToUnicode(token) for token in transcript[:src_len])
    axes.set_xlabel(plot.ToUnicode(transcript), size='x-small', wrap=True)
Пример #2
0
    def testToUnicode(self):
        """Checks plot.ToUnicode keeps text unchanged and decodes UTF-8 bytes."""
        # Non-ASCII text, so encoding/decoding mistakes are visible.
        str_str = 'pójdź kińże tę chmurność w głąb flaszy'
        uni_str = u'pójdź kińże tę chmurność w głąb flaszy'
        # Under Python 3 both literals above are `str`, so on their own they
        # never exercise byte -> text conversion; cover that path explicitly.
        bytes_str = uni_str.encode('utf-8')

        self.assertEqual(plot.ToUnicode(str_str), uni_str)
        self.assertEqual(plot.ToUnicode(str_str), plot.ToUnicode(uni_str))
        self.assertEqual(plot.ToUnicode(bytes_str), uni_str)
Пример #3
0
 def PlotAttention(fig, axes, cur_atten_probs, title, set_x_label):
   """Draws `cur_atten_probs` on `axes` and labels the sequence-index axes.

   Args:
     fig: Figure handle forwarded to plot.AddImage.
     axes: Axes handle that is drawn on and labelled.
     cur_atten_probs: Attention probabilities forwarded to plot.AddImage.
     title: Title forwarded to plot.AddImage.
     set_x_label: When truthy, the x axis is labelled as well; otherwise
       only the y axis receives a label.
   """
   plot.AddImage(fig, axes, cur_atten_probs, title=title)
   output_axis_label = plot.ToUnicode('Output sequence index')
   axes.set_ylabel(output_axis_label, wrap=True)
   if not set_x_label:
     return
   input_axis_label = plot.ToUnicode('Input sequence index')
   axes.set_xlabel(input_axis_label, wrap=True)