Example #1
def visualize_attributions(visualization_list):
    """
    Helper function to call Captum's visualization methods.
    Inputs:
    visualization_list: a list containing the integrated gradients attributions.
    """

    viz.visualize_text(visualization_list)
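A minimal usage sketch for the helper above. The tokens, scores, and class names are illustrative placeholders, and the record is built with Captum's VisualizationDataRecord using positional arguments.

import torch
from captum.attr import visualization as viz

# Hypothetical attribution output for a single sentence (placeholder values).
tokens = ["[CLS]", "a", "fantastic", "performance", "[SEP]"]
word_attributions = torch.tensor([0.0, 0.05, 0.62, 0.31, 0.0])

record = viz.VisualizationDataRecord(
    word_attributions,          # per-token attribution scores
    0.97,                       # predicted probability
    "positive",                 # predicted class
    "positive",                 # true class
    "positive",                 # class the attributions explain
    word_attributions.sum(),    # total attribution score
    tokens,                     # raw token strings
    0.001,                      # convergence delta from integrated gradients
)

visualize_attributions([record])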
Example #2
    def visualize_attribution(self, example, save=False):
        e1_pos = example.e1_idx
        e2_pos = example.e2_idx
        label = example.label
        is_same_sentence = example.sent1 == example.sent2
        
        text, inputs = self._encode(example)

        input_ids, attention_mask = inputs["input_ids"], \
                                    inputs["attention_mask"]

        # Create mappings from words in the original text to tokens.
        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, word) in enumerate(text):
            orig_to_tok_index.append(len(all_doc_tokens))
            tokens = self.tokenizer.tokenize(word)
            for token in tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(token)

        tok_e1_idx = orig_to_tok_index[e1_pos] + 1
        tok_e2_idx = orig_to_tok_index[e2_pos] + 1
        if not is_same_sentence:
            tok_e2_idx += 2

        pred_scores, all_tokens, attributions_sum, delta = self.get_scores_and_attributions(inputs,
                                                                                     tok_e1_idx,
                                                                                     tok_e2_idx,
                                                                                     label)

        # Store the record in a list for visualization purposes.
        cls_vis = viz.VisualizationDataRecord(
                                word_attributions=attributions_sum,
                                pred_prob=torch.max(torch.softmax(pred_scores[0], dim=0)),
                                pred_class=CLASSES[torch.argmax(pred_scores)],
                                true_class=label,
                                attr_class=CLASSES[torch.argmax(pred_scores)],
                                attr_score=attributions_sum.sum(),       
                                raw_input=all_tokens,
                                convergence_score=delta)

        print('\033[1m', 'Visualizations For Prediction', '\033[0m')
        viz.visualize_text([cls_vis])
        print(example.id)
        print(example.sent1[example.e1_idx])
        print(example.sent2[example.e2_idx])
        if save:
            attributions_per_example[example.id] = AttributionPerExample(example, input_ids, self.tokenizer.convert_ids_to_tokens(input_ids), attributions_sum)

            obj = save_viz([cls_vis])
            with open(model_path + example.id + ".html", "w") as html_file:
                html_file.write(obj.data)
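`save_viz` is not shown in this snippet; a hedged guess at its implementation, assuming it simply wraps `viz.visualize_text`, which in recent Captum versions returns an IPython `HTML` object whose `.data` attribute is what gets written to disk:

from captum.attr import visualization as viz

def save_viz(records):
    # Render the records; the returned IPython HTML object carries the markup
    # in its .data attribute, which the caller writes to "<example.id>.html".
    return viz.visualize_text(records)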
Example #3
    def visualize(self, html_filepath: str = None, true_class: str = None):
        tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            if true_class is None:
                true_class = round(float(self.pred_probs))
            predicted_class = round(float(self.pred_probs))
            attr_class = round(float(self.pred_probs))
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        html = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
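This `visualize` method looks like it belongs to the transformers-interpret package built on top of Captum; a hedged usage sketch under that assumption (the model name is only illustrative):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

model_name = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

explainer = SequenceClassificationExplainer(model, tokenizer)
word_attributions = explainer("It was a fantastic performance!")
html = explainer.visualize("attributions.html")  # the method shown above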
Example #4
    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]

        if not self.include_hypothesis:
            tokens = tokens[:self.sep_idx]

        score_viz = [
            self.attributions[i].visualize_attributions(  # type: ignore
                self.pred_probs[i],
                self.labels[i],
                self.labels[i],
                self.labels[i],
                tokens,
            ) for i in range(len(self.attributions))
        ]
        html = viz.visualize_text(score_viz)

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
Example #5
    def visualize(self,
                  html_filepath: str = None,
                  true_classes: List[str] = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        if true_classes is not None and len(
                true_classes) != self.input_ids.shape[1]:
            raise ValueError(
                "The length of `true_classes` must be equal to the number of tokens"
            )

        score_vizs = []
        tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]

        for index in self._selected_indexes:
            pred_prob = torch.max(self.pred_probs[index])
            predicted_class = self.id2label[torch.argmax(
                self.pred_probs[index]).item()]

            attr_class = tokens[index]
            if true_classes is None:
                true_class = predicted_class
            else:
                true_class = true_classes[index]

            score_vizs.append(self.attributions[index].visualize_attributions(
                pred_prob,
                predicted_class,
                true_class,
                attr_class,
                tokens,
            ))

        html = viz.visualize_text(score_vizs)

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
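Example #5 iterates over selected token indexes, which matches the token-classification (NER) explainer in transformers-interpret; a hedged usage sketch under that assumption (the model name is only illustrative):

from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers_interpret import TokenClassificationExplainer

model_name = "dslim/bert-base-NER"  # illustrative
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

ner_explainer = TokenClassificationExplainer(model, tokenizer)
ner_explainer("We visited Paris last summer")
ner_explainer.visualize("ner_attributions.html")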
Example #6
    def visualize(self, html_filepath: str = None, true_class: str = None):
        tokens = self.tokenizer.convert_ids_to_tokens(self.input_ids[0])
        attr_class = self.id2label[int(self.selected_index)]
        if true_class is None:
            true_class = self.predicted_class_name
        score_viz = self.attributions.visualize_attributions(
            self.pred_probs, self.predicted_class_name, true_class, attr_class,
            self.text, tokens)
        html = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
Example #7
def interpret_sentence(docid):
    vis_data_records_ig = []
    model.zero_grad()
    tokens, token_ids, label = get_tokens_by_docid(docid)
    pred = forward_with_sigmoid(data.x, data.edge_index, data.edge_attr,
                                docid)[label]
    pred_ind = round(pred.detach().cpu().item())
    # compute attributions and approximation delta using layer integrated gradients
    token_reference = TokenReferenceBase(reference_token_idx=0)
    reference_indices = token_reference.generate_reference(
        data.x.shape[0], device='cuda:3').unsqueeze(0)
    attributions_ig, delta = lig.attribute(
        data.x.unsqueeze(0),
        reference_indices,
        additional_forward_args=(data.edge_index.unsqueeze(0),
                                 data.edge_attr.unsqueeze(0), docid),
        n_steps=50,
        return_convergence_delta=True,
        internal_batch_size=1)
    print(f'pred: {pred}, delta: {abs(delta)}')
    print(attributions_ig)
    add_attributions_to_visualizer(attributions_ig, tokens, token_ids, pred,
                                   pred_ind, label, delta, vis_data_records_ig)
    visualization.visualize_text(vis_data_records_ig)
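`add_attributions_to_visualizer` is not shown above; a hedged sketch of what it might look like, following the standard Captum tutorial pattern and matching the call signature used here (the class-name formatting is an assumption):

import torch
from captum.attr import visualization

def add_attributions_to_visualizer(attributions, tokens, token_ids, pred,
                                    pred_ind, label, delta, vis_data_records):
    # Collapse the embedding dimension and normalize to unit norm.
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    vis_data_records.append(visualization.VisualizationDataRecord(
        attributions,          # per-token attribution scores
        pred,                  # model score for the target label
        str(pred_ind),         # predicted class
        str(label),            # true class
        str(label),            # class the attributions are computed for
        attributions.sum(),
        tokens,
        delta))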
Example #8
    def visualize(self, html_filepath: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.
        """
        tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]
        predicted_answer = self.predicted_answer

        self.position = 0
        start_pred_probs = self._forward(self.input_ids, self.token_type_ids,
                                         self.position_ids)
        start_pos = self.start_pos
        start_pos_str = tokens[start_pos] + " (" + str(start_pos) + ")"
        start_score_viz = self.start_attributions.visualize_attributions(
            float(start_pred_probs),
            str(predicted_answer),
            start_pos_str,
            start_pos_str,
            tokens,
        )

        self.position = 1

        end_pred_probs = self._forward(self.input_ids, self.token_type_ids,
                                       self.position_ids)
        end_pos = self.end_pos
        end_pos_str = tokens[end_pos] + " (" + str(end_pos) + ")"
        end_score_viz = self.end_attributions.visualize_attributions(
            float(end_pred_probs),
            str(predicted_answer),
            end_pos_str,
            end_pos_str,
            tokens,
        )

        html = viz.visualize_text([start_score_viz, end_score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
Example #9
    def visualize(self, html_filepath: str = None):
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        predicted_answer = self.predicted_answer

        self.position = 0
        start_pred_probs = self._forward(
            self.input_ids, self.token_type_ids, self.position_ids
        )
        start_pos = self.start_pos
        start_pos_str = tokens[start_pos] + " (" + str(start_pos) + ")"
        start_score_viz = self.start_attributions.visualize_attributions(
            float(start_pred_probs),
            str(predicted_answer),
            start_pos_str,
            start_pos_str,
            tokens,
        )

        self.position = 1

        end_pred_probs = self._forward(
            self.input_ids, self.token_type_ids, self.position_ids
        )
        end_pos = self.end_pos
        end_pos_str = tokens[end_pos] + " (" + str(end_pos) + ")"
        end_score_viz = self.end_attributions.visualize_attributions(
            float(end_pred_probs),
            str(predicted_answer),
            end_pos_str,
            end_pos_str,
            tokens,
        )

        html = viz.visualize_text([start_score_viz, end_score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
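The two `visualize` variants above appear to come from transformers-interpret's question-answering explainer; a hedged usage sketch under that assumption (the model name is only illustrative):

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers_interpret import QuestionAnsweringExplainer

model_name = "deepset/roberta-base-squad2"  # illustrative
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

qa_explainer = QuestionAnsweringExplainer(model, tokenizer)
qa_explainer("Who wrote Hamlet?",
             "Hamlet is a tragedy written by William Shakespeare.")
qa_explainer.visualize("qa_attributions.html")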
Example #10
    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            if true_class is None:
                true_class = round(float(self.pred_probs))
            predicted_class = round(float(self.pred_probs))
            attr_class = round(float(self.pred_probs))
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        html = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
Example #11
    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]

        score_viz = [
            self.attributions[i].visualize_attributions(  # type: ignore
                self.pred_probs[i],
                "",  # including a predicted class name does not make sense for this explainer
                "n/a" if not true_class else
                true_class,  # no true class name for this explainer by default
                self.labels[i],
                tokens,
            ) for i in range(len(self.attributions))
        ]

        html = viz.visualize_text(score_viz)

        new_html_data = html._repr_html_().replace("Predicted Label",
                                                   "Prediction Score")
        new_html_data = new_html_data.replace("True Label", "n/a")
        html.data = new_html_data

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
Example #12
                            attributions.sum(),       
                            text,
                            delta))
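
# The call above is truncated at the top. Judging by its closing arguments
# (attributions.sum(), text, delta), it is most likely appending a
# VisualizationDataRecord inside an add_attributions_to_visualizer-style
# helper. A hedged reconstruction of the full call (the LABEL.vocab lookups
# are an assumption based on the torchtext setup further below):
#
# vis_data_records_ig.append(visualization.VisualizationDataRecord(
#     attributions,                  # normalized per-token attributions
#     pred,                          # model output probability
#     LABEL.vocab.itos[pred_ind],    # predicted class name
#     LABEL.vocab.itos[label],       # true class name
#     LABEL.vocab.itos[1],           # class the attributions explain
#     attributions.sum(),
#     text,
#     delta))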

interpret_sentence(model, 'It was a fantastic performance !', label=1)
interpret_sentence(model, 'Best film ever', label=1)
interpret_sentence(model, 'Such a great show!', label=1)
interpret_sentence(model, 'It was a horrible movie', label=0)
interpret_sentence(model, 'I\'ve never watched something as bad', label=0)
interpret_sentence(model, 'It is a disgusting movie!', label=0)

interpret_sentence(model, 'It was not a bad movie!', label=1)
interpret_sentence(model, 'I would like watch this movie again', label=1)

print('Visualize attributions based on Integrated Gradients')
visualization.visualize_text(vis_data_records_ig)

try:
  # tqdm newline issue: https://stackoverflow.com/questions/41707229/tqdm-printing-to-newline
  tqdm._instances.clear()
except Exception:
  pass

TEXT.build_vocab(trn, vectors="glove.6B.300d")
# hint: one of the imports has not been used yet; maybe it is needed in the line above :)
LABEL.build_vocab(trn)
word_embeddings = TEXT.vocab.vectors

kernel_sizes = [3, 4, 5]
vocab_size = len(TEXT.vocab)
dropout = 0.5
Example #13
        for label_attribution in example_attributions:
            label_id = label_attribution['label_id']
            true_label = label_attribution['true_label']
            example_index = label_attribution['example_index']
            token_index = np.where(
                preds[example_index] == label_attribution['label_id'])[0][0]
            attributions_sum = label_attribution['attributions']
            delta = label_attribution['delta']
            score_vis = viz.VisualizationDataRecord(
                word_attributions=attributions_sum,
                pred_prob=max(confidences[example_index][token_index]),
                pred_class=labels[label_id],
                true_class=labels[true_label],
                attr_class=labels[label_id],
                attr_score=attributions_sum.sum(),
                raw_input=all_tokens[example_index],
                convergence_score=delta)
            all_visualizations.append(score_vis)

    index += 1

if all_visualizations:
    display = viz.visualize_text(all_visualizations)

    with open(f'{args.explanations_dir}/explanations.html', 'w') as file:
        file.write(display.data)

whole_process_end = time.time()
whole_process_duration = round(whole_process_end - whole_process_start, 2)
logger.info(f'Whole process took {whole_process_duration} seconds')
Example #14
    def visualize(self):
        visualization.visualize_text(self.vis_data_records_ig)
        self.vis_data_records_ig = []
Example #15
    def show_words_importance(self, vis_data_records_ig):
        visualization.visualize_text(vis_data_records_ig)
Example #16
def captum_text_interpreter(text,
                            model,
                            bpetokenizer,
                            idx2label,
                            max_len=80,
                            tokenizer=None,
                            multiclass=False):
    if type(text) == list:
        text = " ".join(text)

    d = data_utils.process_data_for_transformers(text, bpetokenizer, tokenizer,
                                                 0)
    d = {
        "ids": torch.tensor([d['ids']], dtype=torch.long),
        "mask": torch.tensor([d['mask']], dtype=torch.long),
        "token_type_ids": torch.tensor([d['token_type_ids']], dtype=torch.long)
    }

    try:
        orig_tokens = [0] + bpetokenizer.encode(text).ids + [2]
        orig_tokens = [bpetokenizer.id_to_token(j) for j in orig_tokens]
    except Exception:
        orig_tokens = tokenizer.tokenize(text, add_special_tokens=True)

    model.eval()
    if multiclass:
        preds_proba = torch.sigmoid(
            model(d["ids"], d["mask"],
                  d["token_type_ids"])).detach().cpu().numpy()
        preds = preds_proba.argmax(-1)
        preds_proba = preds_proba[0][preds[0][0]]
        predicted_class = idx2label[preds[0][0]]
    else:
        preds_proba = torch.sigmoid(
            model(d["ids"], d["mask"],
                  d["token_type_ids"])).detach().cpu().numpy()
        preds = np.round(preds_proba)
        preds_proba = preds_proba[0][0]
        predicted_class = idx2label[preds[0][0]]

    lig = LayerIntegratedGradients(model, model.base_model.embeddings)

    reference_indices = [0] + [1] * (d["ids"].shape[1] - 2) + [2]
    reference_indices = torch.tensor([reference_indices], dtype=torch.long)

    attributions_ig, delta = lig.attribute(
        inputs=d["ids"],
        baselines=reference_indices,
        additional_forward_args=(d["mask"], d["token_type_ids"]),
        return_convergence_delta=True)

    attributions = attributions_ig.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().cpu().numpy()

    visualization.visualize_text([
        visualization.VisualizationDataRecord(word_attributions=attributions,
                                              pred_prob=preds_proba,
                                              pred_class=predicted_class,
                                              true_class=predicted_class,
                                              attr_class=predicted_class,
                                              attr_score=attributions.sum(),
                                              raw_input=orig_tokens,
                                              convergence_score=delta)
    ])
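A hedged usage sketch for `captum_text_interpreter`; the model, tokenizer, and label map below are placeholders for whatever the surrounding project provides (the model is assumed to expose `base_model.embeddings` and to accept `(ids, mask, token_type_ids)`):

# Placeholder objects; substitute the project's trained model and tokenizers.
idx2label = {0: "negative", 1: "positive"}

captum_text_interpreter(
    text="It was a fantastic performance!",
    model=trained_model,         # assumed: transformer classifier with .base_model.embeddings
    bpetokenizer=bpe_tokenizer,  # assumed: a tokenizers-library BPE tokenizer (or None)
    idx2label=idx2label,
    tokenizer=None,              # optional Hugging Face tokenizer fallback
    multiclass=False,
)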