def TrimPaddingAndPlotAttention(fig,
                                axes,
                                atten_matrix,
                                src_len,
                                tgt_len,
                                transcript=None,
                                **kwargs):
  """Plots atten_matrix after cropping away its padded rows and columns.

  Intended to be used as a plot function with MatplotlibFigureSummary.

  Args:
    fig: Matplotlib figure handle.
    axes: Matplotlib axes handle to draw on.
    atten_matrix: 2D ndarray laid out as (tgt_time, src_time).
    src_len: Integer; number of leading src_time columns to keep.
    tgt_len: Integer; number of leading tgt_time rows to keep.
    transcript: Optional transcript for the target sequence, shown as the
      x-axis label. An np.ndarray is space-joined after keeping its first
      src_len entries; any other value is handed to plot.ToUnicode as-is.
    **kwargs: Extra keyword args forwarded to plot.AddImage.
  """
  trimmed = atten_matrix[:tgt_len, :src_len]
  # The color limits are pinned to [0, 1] for every plot drawn here.
  plot.AddImage(fig, axes, trimmed, clim=(0, 1), **kwargs)
  if transcript is None:
    return
  if isinstance(transcript, np.ndarray):
    # NOTE(review): the array is cut with src_len even though the transcript
    # is documented as belonging to the target sequence -- confirm intended.
    transcript = ' '.join(transcript[:src_len])
  label = plot.ToUnicode(transcript)
  axes.set_xlabel(label, size='x-small', wrap=True)
def testToUnicode(self):
  """Checks that plot.ToUnicode yields the same text for str and u'' input."""
  # Under Python 3 both literals are str; the u prefix only matters on Py2.
  plain_text = 'pójdź kińże tę chmurność w głąb flaszy'
  unicode_text = u'pójdź kińże tę chmurność w głąb flaszy'
  self.assertEqual(plot.ToUnicode(plain_text), unicode_text)
  from_plain = plot.ToUnicode(plain_text)
  from_unicode = plot.ToUnicode(unicode_text)
  self.assertEqual(from_plain, from_unicode)
def PlotAttention(fig, axes, cur_atten_probs, title, set_x_label):
  """Draws attention probabilities as an image and labels its axes.

  Args:
    fig: Matplotlib figure handle.
    axes: Matplotlib axes handle to draw on.
    cur_atten_probs: Attention probabilities handed to plot.AddImage.
    title: Title for the subplot.
    set_x_label: When truthy, the x axis is labelled as well.
  """
  plot.AddImage(fig, axes, cur_atten_probs, title=title)
  y_label = plot.ToUnicode('Output sequence index')
  axes.set_ylabel(y_label, wrap=True)
  if not set_x_label:
    return
  x_label = plot.ToUnicode('Input sequence index')
  axes.set_xlabel(x_label, wrap=True)