コード例 #1
0
ファイル: Decoder.py プロジェクト: jogonba2/SHTE
    def _decode_sample(self, x, topk,
                       average_encoders = False,
                       selected_encoder = -1,
                       visualization = False,
                       rows_visualization = None,
                       cols_visualization = None):
        """Return an extractive summary of ``x`` built from sentence attention.

        ``x`` is treated as a sequence of sentences separated by " . ".
        Each sentence is scored by the mean of its column in the attention
        matrix returned by ``self._get_sample_attn``. The ``topk``
        highest-scoring sentences are kept and re-joined with " . " in their
        original (ascending index) order.

        Args:
            x: Document text whose sentences are delimited by " . ".
            topk: Maximum number of sentences to keep.
            average_encoders: Forwarded unchanged to ``self._get_sample_attns``.
            selected_encoder: Forwarded unchanged to ``self._get_sample_attns``.
            visualization: When true, plot the intermediate attention arrays
                via ``Visualization.visualize_attentions``.
            rows_visualization: Grid rows for the first (raw) attention plot.
            cols_visualization: Grid columns for the first (raw) attention plot.

        Returns:
            The summary string.
        """
        # Second element of the returned pair is not needed here.
        encoder_attns, _ = self._get_sample_attns(
            x,
            average_encoders=average_encoders,
            selected_encoder=selected_encoder,
        )
        if visualization:
            Visualization.visualize_attentions(encoder_attns, 16, 16,
                                               rows_visualization,
                                               cols_visualization)

        sample_attn = self._get_sample_attn(encoder_attns)
        sentences = x.split(" . ")
        # Presumably documents shorter than document_max_sents are padded at
        # the front; this is how many leading rows/columns belong to padding.
        n_pad = max(0, self.document_max_sents - len(sentences))
        if visualization:
            Visualization.visualize_attentions(sample_attn, 16, 16, 1, 1)

        trimmed = sample_attn[n_pad:, n_pad:]
        if visualization:
            Visualization.visualize_attentions(trimmed, 16, 16, 1, 1)

        # Column mean: average attention each sentence receives across rows.
        scores = trimmed.sum(axis = 0) / trimmed.shape[0]
        best_first = np.argsort(scores)[::-1]
        # Restore document order among the selected sentences.
        chosen = sorted(best_first[:topk])
        if visualization:
            Visualization.visualize_attentions(scores, 16, 16, 1, 1)

        return " . ".join(sentences[idx] for idx in chosen)
コード例 #2
0
 # Show the prediction and gold label for example `inp`, plot one attention
 # head over its tokens, then read the next example index from stdin.
 # NOTE(review): s, pred, y_dv, inp, max_words and attn_i are bound outside
 # this excerpt; presumably s is the text of example `inp` and attn_i its
 # per-head attention stack (heads, max_words, max_words) — confirm with caller.
 l_words = s.split()
 print("Pred: %s" % (pred))
 print("Truth: %s" % (y_dv[inp]))
 # The attention treats the first `pad` positions as padding (pre-padding up
 # to max_words). Compute that count directly instead of inserting "<pad>"
 # tokens at the front one by one (O(n^2)) and then slicing them off again.
 pad = max(0, max_words - len(l_words))
 attn_i = attn_i[:, pad:, pad:]
 # Keep a single attention head (index 2) for the plot.
 attn_i = attn_i[2]
 # Disabled alternative to the head selection above:
 #attn_i = attn_i.sum(axis=0)
 output_file = "sentiment_att.pdf"
 Visualization.visualize_attentions(attn_i,
                                    30,
                                    30,
                                    rows=1,
                                    columns=1,
                                    ticks=l_words,
                                    output_file=output_file,
                                    save=True)
 # Next example index; int() raises ValueError on non-numeric input.
 inp = int(input())
 """
 attn_i = attns[inp]
 l_words = x_dv[inp].split()
 print("Pred: %s, True: %s" % (preds[inp], truths[inp]))
 pad = 0
 while len(l_words) < max_words:
     l_words.insert(0, "<pad>")
     pad += 1
 l_words = l_words[pad:]
 attn_i = attn_i[:, pad:, pad:]
 #attn_i = attn_i.sum(axis=0)