示例#1
0
    def add_attributions_to_visualizer(
        self,
        attributions,
        tokens,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        delta,
        vis_data_records,
    ):
        """Collapse attributions to per-token scores and append a visualization record.

        The attribution tensor is summed over its third axis, squeezed on the
        first, scaled to unit L2 norm and converted to a NumPy array. The
        result is wrapped in a ``VisualizationDataRecord`` that is appended to
        ``vis_data_records`` (mutated in place).

        Args:
            attributions: attribution tensor; presumably shaped
                (1, seq_len, emb_dim) -- TODO confirm against the caller.
            tokens: tokens displayed next to the scores.
            pred_prob, pred_class, true_class, attr_class: passed through to
                the record unchanged.
            delta: convergence delta, passed through unchanged.
            vis_data_records: list that receives the new record.
        """
        per_token = attributions.sum(dim=2).squeeze(0)
        unit_scores = (per_token / torch.norm(per_token)).cpu().detach().numpy()

        record = visualization.VisualizationDataRecord(
            unit_scores,
            pred_prob,
            pred_class,
            true_class,
            attr_class,
            unit_scores.sum(),
            tokens,
            delta,
        )
        vis_data_records.append(record)
示例#2
0
 def add_attributions_to_visualizer(self, attributions, tokens, pred,
                                    pred_ind, label, delta,
                                    vis_data_records):
     """Append a visualization record that hides the first and last positions.

     The first and last entries of ``attributions`` and ``tokens`` are dropped
     from the displayed arrays (presumably special boundary tokens such as
     [CLS]/[SEP] -- TODO confirm). The attribution score handed to the record
     is the sum over the full, unsliced ``attributions``. The attributed-class
     field is the literal "label".
     """
     shown_attributions = attributions[1:-1]
     shown_tokens = tokens[1:-1]
     total_score = attributions.sum()
     vis_data_records.append(
         visualization.VisualizationDataRecord(
             shown_attributions, pred, pred_ind, label, "label",
             total_score, shown_tokens, delta))
def add_attr_viz(attributions, text, pred, pred_ind, label, delta,
                 vis_data_records):
    """Append a token-attribution record whose attributed class is 'location'.

    Attributions are summed over the third axis, squeezed on the first and
    scaled in place to unit L2 norm before conversion to NumPy. Predicted
    and true class names are looked up in the module-level ``classes``.
    ``vis_data_records`` is mutated in place.
    """
    token_scores = attributions.sum(dim=2).squeeze(0)
    token_scores /= torch.norm(token_scores)
    token_scores = token_scores.cpu().detach().numpy()

    predicted_name = classes[pred_ind]
    true_name = classes[label]
    vis_data_records.append(
        visualization.VisualizationDataRecord(
            token_scores,
            pred,
            predicted_name,
            true_name,
            'location',
            token_scores.sum(),
            text,
            delta,
        ))
示例#4
0
    def visualize_attribution(self, example, save=False):
        """Compute, display and optionally save token attributions for *example*.

        Translates the two entity word positions (``example.e1_idx`` /
        ``example.e2_idx``) into sub-word token positions, asks
        ``self.get_scores_and_attributions`` for prediction scores and
        attributions, renders them with ``viz.visualize_text`` and prints the
        example id and both entity words.

        When ``save`` is true, the attributions are also stored in the
        module-level ``attributions_per_example`` dict (keyed by
        ``example.id``). The rendered HTML is written to
        ``model_path + example.id + ".html"``.
        """
        e1_pos = example.e1_idx
        e2_pos = example.e2_idx
        label = example.label
        is_same_sentence = example.sent1 == example.sent2

        text, inputs = self._encode(example)
        input_ids = inputs["input_ids"]

        # Map every word of the original text to the index of its first
        # sub-word token so word-level entity positions can be translated.
        orig_to_tok_index = []
        all_doc_tokens = []
        for word in text:
            orig_to_tok_index.append(len(all_doc_tokens))
            all_doc_tokens.extend(self.tokenizer.tokenize(word))

        # The +1 offset presumably skips a leading special token added by
        # self._encode (e.g. [CLS]) -- TODO confirm.
        tok_e1_idx = orig_to_tok_index[e1_pos] + 1
        tok_e2_idx = orig_to_tok_index[e2_pos] + 1
        if not is_same_sentence:
            # Extra shift when the two sentences differ, presumably for the
            # separator tokens between them -- TODO confirm.
            tok_e2_idx += 2

        pred_scores, all_tokens, attributions_sum, delta = (
            self.get_scores_and_attributions(inputs, tok_e1_idx, tok_e2_idx,
                                             label))

        # The predicted class doubles as the attributed class.
        predicted_class = CLASSES[torch.argmax(pred_scores)]
        cls_vis = viz.VisualizationDataRecord(
            word_attributions=attributions_sum,
            pred_prob=torch.max(torch.softmax(pred_scores[0], dim=0)),
            pred_class=predicted_class,
            true_class=label,
            attr_class=predicted_class,
            attr_score=attributions_sum.sum(),
            raw_input=all_tokens,
            convergence_score=delta)

        print('\033[1m', 'Visualizations For Prediction', '\033[0m')
        viz.visualize_text([cls_vis])
        print(example.id)
        print(example.sent1[example.e1_idx])
        print(example.sent2[example.e2_idx])
        if save:
            attributions_per_example[example.id] = AttributionPerExample(
                example, input_ids,
                self.tokenizer.convert_ids_to_tokens(input_ids),
                attributions_sum)

            html = save_viz([cls_vis])
            # Explicit UTF-8: the page embeds model tokens that may be
            # non-ASCII, which the platform default encoding may not support.
            with open(model_path + example.id + ".html", "w",
                      encoding="utf-8") as html_file:
                html_file.write(html.data)
示例#5
0
    def visualize_attributions(self, pred_prob, pred_class, true_class,
                               attr_class, text, all_tokens):
        """Build a visualization record from the attribution state stored on self.

        Uses ``self.attributions_sum`` as the per-token scores (and its sum as
        the attribution score) and ``self.delta`` as the convergence score.
        The ``text`` argument is accepted but not used.
        """
        token_scores = self.attributions_sum
        return viz.VisualizationDataRecord(
            token_scores, pred_prob, pred_class, true_class, attr_class,
            token_scores.sum(), all_tokens, self.delta)
示例#6
0
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    """Append a unit-normalized token-attribution record to ``vis_data_records``.

    Class names come from the module-level ``LABEL.vocab.itos``. The
    attributed class is always the entry at index 1.
    """
    summed = attributions.sum(dim=2).squeeze(0)
    scores = (summed / torch.norm(summed)).cpu().detach().numpy()

    itos = LABEL.vocab.itos
    record = visualization.VisualizationDataRecord(
        scores,
        pred,
        itos[pred_ind],
        itos[label],
        itos[1],
        scores.sum(),
        text,
        delta,
    )
    vis_data_records.append(record)
示例#7
0
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, delta,
                                   vis_data_records, vectors):
    """Append a unit-normalized token-attribution record to ``vis_data_records``.

    ``pred_ind`` is converted with ``.item()`` and used to look up the
    class name in ``vectors``. That name is used for both the predicted and
    the true class; the attributed class is ``vectors[1]``. The displayed
    text is truncated to the number of attribution scores, and ``delta`` is
    moved to CPU.
    """
    token_scores = attributions.sum(dim=2).squeeze(0)
    token_scores = token_scores / torch.norm(token_scores)
    token_scores = token_scores.cpu().detach().numpy()
    class_index = pred_ind.item()

    vis_data_records.append(
        visualization.VisualizationDataRecord(
            token_scores,
            pred,
            vectors[class_index],
            vectors[class_index],
            vectors[1],
            token_scores.sum(),
            text[:len(token_scores)],
            delta.cpu(),
        ))
示例#8
0
    def add_attributions_to_visualizer(self, attributions, text, pred, pred_ind,
                                       label, delta, target):
        """Append a unit-normalized record to ``self.vis_data_records_ig``.

        Predicted, true and attributed class names are looked up in
        ``self.label_idx`` using ``pred_ind``, ``label`` and ``target``.
        """
        collapsed = attributions.sum(dim=2).squeeze(0)
        normalized = collapsed / torch.norm(collapsed)
        word_scores = normalized.cpu().detach().numpy()

        names = self.label_idx
        self.vis_data_records_ig.append(
            visualization.VisualizationDataRecord(
                word_scores,
                pred,
                names[pred_ind],
                names[label],
                names[target],
                word_scores.sum(),
                text,
                delta,
            ))
示例#9
0
def get_shap_attributions(text, tech_tv, tech_bb, true_label):
    """Explain the model's top prediction for *text* with Shapley value sampling.

    The input string is built with ``data_utils.stitch_bert_string`` from
    ``text``, ``tech_tv`` and ``tech_bb``. Attributions are computed against
    a reference input built from the tokenizer's pad token, targeting the
    top-scoring class.

    Returns:
        A tuple ``(score_vis, pred_proba, labels)``:
        ``score_vis`` is a ``VisualizationDataRecord`` for the top class
        (convergence score fixed at 0), ``pred_proba`` is the softmax over
        the first example's outputs, and ``labels`` is the list of
        ``model.config.id2label`` values.

    Raises:
        Whatever ``get_model`` raises; the failure is logged with its
        traceback before being re-raised.
    """
    try:
        model, tokenizer = get_model(settings.MODEL_NAME, settings.MODEL_STAGE)
    except Exception:
        # Continuing here would only crash later with an UnboundLocalError on
        # ``tokenizer``; surface the real cause instead.
        logger.exception("Failed to load model %s (stage %s)",
                         settings.MODEL_NAME, settings.MODEL_STAGE)
        raise

    ref_token_id = tokenizer.pad_token_id  # token used for generating the reference input
    sep_token_id = tokenizer.sep_token_id  # separator token, also added at the end of the text
    cls_token_id = tokenizer.cls_token_id  # token prepended to the concatenated sequence

    bert_string = data_utils.stitch_bert_string("", text, tech_tv, tech_bb)
    input_ids, ref_input_ids, _ = construct_input_ref_pair(
        tokenizer, bert_string, ref_token_id, sep_token_id, cls_token_id)

    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist())

    pred = model(input_ids)[0]
    pred_proba = torch.softmax(pred, dim=1)[0]
    # Top-scoring class of the first (only) example, as a plain int.
    pred_idx = int(torch.argmax(pred[0]))
    pred_label = model.config.id2label[pred_idx]

    model_custom = ModelWrapper(model)
    shapley = ShapleyValueSampling(model_custom.custom_forward)
    attributions = shapley.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        target=pred_idx,
    )

    score_vis = viz.VisualizationDataRecord(
        attributions[0, :],
        pred_proba[pred_idx],
        pred_label,
        true_label,
        pred_label,
        attributions.sum(),
        all_tokens,
        0,
    )

    labels = list(model.config.id2label.values())

    return score_vis, pred_proba, labels
示例#10
0
def captum_interactive(request):
    """Django view: attack the submitted text and render Captum attributions.

    On a valid POST the view first looks for a stored "captum" result with
    the same input text in the session. If one exists, it is moved to the
    front of the history and the view redirects.

    Otherwise it:
    - loads ``textattack/<model_name>``;
    - attacks the input with ``recipe_name``;
    - computes LayerIntegratedGradients attributions (over the BERT
      embeddings) for the original and the perturbed text;
    - renders both as HTML and prepends the result to the session history,
      which is capped at 10 entries.

    Redirects to ``webdemo:index`` on success. Returns ``HttpResponseNotFound``
    for an invalid form, an attack that yields fewer than three output
    lines, or a non-POST request.
    """
    if request.method != 'POST':
        return HttpResponseNotFound('<h1>Not Found</h1>')

    stored_posts = request.session.get("TextAttackResult")
    form = CustomData(request.POST)
    if not form.is_valid():
        return HttpResponseNotFound('Failed')

    input_text = form.cleaned_data['input_text']
    model_name = form.cleaned_data['model_name']
    recipe_name = form.cleaned_data['recipe_name']

    # Re-serve an earlier result for the same input: move it to the front of
    # the history and redirect without recomputing anything.
    if stored_posts:
        history = json.loads(stored_posts)
        for idx, el in enumerate(history):
            if el["type"] == "captum" and el["input_string"] == input_text:
                history.insert(0, history.pop(idx))
                request.session["TextAttackResult"] = json.dumps(history[:10])
                return HttpResponseRedirect(reverse('webdemo:index'))

    original_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "textattack/" + model_name)
    original_tokenizer = textattack.models.tokenizers.AutoTokenizer(
        "textattack/" + model_name)
    model = textattack.models.wrappers.HuggingFaceModelWrapper(
        original_model, original_tokenizer)

    # NOTE(review): "cuda:2" assumes the host has at least three GPUs; with
    # fewer devices moving the model there will fail.
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    clone = deepcopy(model)
    clone.model.to(device)

    def calculate(input_ids, token_type_ids, attention_mask):
        # First element of the model output, treated as the class logits.
        return clone.model(input_ids, token_type_ids, attention_mask)[0]

    attack = textattack.commands.attack.attack_args_helpers.parse_attack_from_args(
        Args(model_name, recipe_name))
    attacked_text = textattack.shared.attacked_text.AttackedText(input_text)
    attack.goal_function.init_attack_example(attacked_text, 1)
    goal_func_result, _ = attack.goal_function.get_result(attacked_text)

    result = next(
        attack.attack_dataset([(input_text, goal_func_result.output)]))
    result_parsed = result.str_lines()
    if len(result_parsed) < 3:
        return HttpResponseNotFound('Failed')
    output_text = result_parsed[2]

    orig = result.original_text()
    pert = result.perturbed_text()

    batch_encoded = captum_form(model.tokenizer.batch_encode([orig]), device)
    x = calculate(**batch_encoded)

    pert_batch_encoded = captum_form(model.tokenizer.batch_encode([pert]),
                                     device)
    x_pert = calculate(**pert_batch_encoded)

    lig = LayerIntegratedGradients(calculate, clone.model.bert.embeddings)
    # Targets reuse the forward passes above instead of recomputing them.
    attributions, delta = lig.attribute(
        inputs=batch_encoded['input_ids'],
        additional_forward_args=(batch_encoded['token_type_ids'],
                                 batch_encoded['attention_mask']),
        n_steps=10,
        target=torch.argmax(x).item(),
        return_convergence_delta=True)

    attributions_pert, delta_pert = lig.attribute(
        inputs=pert_batch_encoded['input_ids'],
        additional_forward_args=(pert_batch_encoded['token_type_ids'],
                                 pert_batch_encoded['attention_mask']),
        n_steps=10,
        target=torch.argmax(x_pert).item(),
        return_convergence_delta=True)

    # Collapse the embedding axis and scale each sequence to unit L2 norm.
    atts = attributions.sum(dim=-1).squeeze(0)
    atts = atts / torch.norm(atts)

    atts_pert = attributions_pert.sum(dim=-1).squeeze(0)
    # Normalize by the perturbed sequence's own norm (previously the
    # original's norm was used by mistake).
    atts_pert = atts_pert / torch.norm(atts_pert)

    all_tokens = original_tokenizer.tokenizer.convert_ids_to_tokens(
        batch_encoded['input_ids'][0])
    all_tokens_pert = original_tokenizer.tokenizer.convert_ids_to_tokens(
        pert_batch_encoded['input_ids'][0])

    # Only the first 45 tokens are rendered.
    max_tokens = 45

    # pred_prob must be a probability, so softmax the logits before taking
    # the maximum (previously the raw max logit was shown).
    v = viz.VisualizationDataRecord(
        atts[:max_tokens].detach().cpu(),
        torch.softmax(x, dim=1).max().item(),
        torch.argmax(x, dim=1).item(),
        goal_func_result.output, 2,
        atts.sum().detach(),
        all_tokens[:max_tokens], delta)

    v_pert = viz.VisualizationDataRecord(
        atts_pert[:max_tokens].detach().cpu(),
        torch.softmax(x_pert, dim=1).max().item(),
        torch.argmax(x_pert, dim=1).item(), goal_func_result.output, 2,
        atts_pert.sum().detach(), all_tokens_pert[:max_tokens], delta_pert)

    formattedHTML = formatDisplay([v, v_pert])

    post = {
        "type": "captum",
        "input_string": input_text,
        "model_name": model_name,
        "recipe_name": recipe_name,
        "output_string": output_text,
        "html_input_string": formattedHTML[0],
        "html_output_string": formattedHTML[1],
    }

    if stored_posts:
        history = json.loads(stored_posts)
        history.insert(0, post)
        request.session["TextAttackResult"] = json.dumps(history[:10])
    else:
        request.session["TextAttackResult"] = json.dumps([post])

    return HttpResponseRedirect(reverse('webdemo:index'))
示例#11
0
    # Fragment of a loop body (the loop header is not visible here). For the
    # entry at position ``index``, turn each stored per-label attribution into
    # a visualization record appended to ``all_visualizations``.
    if all_attributions:
        example_attributions = all_attributions[index]
        for label_attribution in example_attributions:
            label_id = label_attribution['label_id']
            true_label = label_attribution['true_label']
            example_index = label_attribution['example_index']
            # First position in this example's predictions equal to the label
            # id; raises IndexError if no position matches.
            token_index = np.where(
                preds[example_index] == label_attribution['label_id'])[0][0]
            attributions_sum = label_attribution['attributions']
            delta = label_attribution['delta']
            # The stored label is used as both predicted and attributed class;
            # pred_prob is the largest confidence value at ``token_index``.
            score_vis = viz.VisualizationDataRecord(
                word_attributions=attributions_sum,
                pred_prob=max(confidences[example_index][token_index]),
                pred_class=labels[label_id],
                true_class=labels[true_label],
                attr_class=labels[label_id],
                attr_score=attributions_sum.sum(),
                raw_input=all_tokens[example_index],
                convergence_score=delta)
            all_visualizations.append(score_vis)

    # Move on to the next position (presumably the next example).
    index += 1

# Render every collected record into a single HTML page.
if all_visualizations:
    display = viz.visualize_text(all_visualizations)

    # NOTE(review): no explicit encoding is given, so non-ASCII tokens may
    # fail to write on platforms whose default encoding is not UTF-8.
    with open(f'{args.explanations_dir}/explanations.html', 'w') as file:
        file.write(display.data)

whole_process_end = time.time()
示例#12
0
def interpret_sentence(flair_model_wrapper,
                       lig,
                       sentence,
                       target_label,
                       visualization_list,
                       n_steps=100,
                       estimation_method="gausslegendre",
                       internal_batch_size=None):
    """Explain a Flair model's prediction for one sentence with Captum.

    Args:
        flair_model_wrapper: wrapper exposing a Captum-compatible forward and
            the ``tokenizer``, ``label_dictionary`` and ``device`` attributes.
        lig: the LayerIntegratedGradients object used for attribution.
        sentence: raw sentence text to interpret.
        target_label: ground-truth class label; attributions are computed
            for this class.
        visualization_list: list that receives the visualization record.
        n_steps, estimation_method, internal_batch_size: forwarded to
            ``lig.attribute`` (the estimation method as ``method=``).

    Returns:
        ``(readable_tokens, word_attributions, convergence_delta)`` for sanity
        checks, where ``convergence_delta`` is ``abs(delta)``.
    """
    wrapper = flair_model_wrapper
    tokenizer = wrapper.tokenizer
    label_dictionary = wrapper.label_dictionary

    target_index = label_dictionary.get_idx_for_item(target_label)

    # Tokenize the way Flair does before handing the text to the tokenizer.
    tokenized_sentence = Sentence(sentence).to_tokenized_string()

    input_ids = tokenizer.encode(
        tokenized_sentence,
        add_special_tokens=False,
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(wrapper.device)

    # Baseline: a tensor shaped like the input, filled with the padding id.
    ref_base_line = torch.full_like(input_ids, tokenizer.pad_token_id)

    # Sub-word tokens (a word may be split into several), with the "▁"
    # marker the tokenizer attaches removed.
    readable_tokens = [
        piece.replace("▁", "")
        for piece in tokenizer.convert_ids_to_tokens(input_ids[0])
    ]

    # The wrapper returns logits (preferred by Captum, see
    # https://github.com/pytorch/captum/issues/355#issuecomment-619610044);
    # softmax is applied here to the first row only.
    probabilities = torch.nn.functional.softmax(wrapper(input_ids)[0], dim=0)
    conf, idx = torch.max(probabilities, 0)
    prediction_confidence = conf.item()
    pred_label = label_dictionary.get_item_for_index(idx.item())

    attributions_ig, delta = lig.attribute(
        input_ids,
        baselines=ref_base_line,
        n_steps=n_steps,
        return_convergence_delta=True,
        target=target_index,
        method=estimation_method,
        internal_batch_size=internal_batch_size,
    )

    convergence_delta = abs(delta)
    print('pred: ', idx.item(), '(', '%.2f' % conf.item(), ')', ', delta: ',
          convergence_delta)

    word_attributions, attribution_score = summarize_attributions(
        attributions_ig)

    visualization_list.append(
        viz.VisualizationDataRecord(
            word_attributions=word_attributions,
            pred_prob=prediction_confidence,
            pred_class=pred_label,
            true_class=target_label,
            attr_class=target_label,
            attr_score=attribution_score,
            raw_input=readable_tokens,
            convergence_score=delta,
        ))

    return readable_tokens, word_attributions, convergence_delta
示例#13
0
def captum_text_interpreter(text,
                            model,
                            bpetokenizer,
                            idx2label,
                            max_len=80,
                            tokenizer=None,
                            multiclass=False):
    """Display LayerIntegratedGradients attributions for a single text.

    Args:
        text: input text, or a list of words that is joined with spaces.
        model: model called as ``model(ids, mask, token_type_ids)``; its
            sigmoid output is used as class probabilities.
        bpetokenizer: BPE tokenizer used to recover display tokens; if that
            fails, ``tokenizer.tokenize`` is used instead.
        idx2label: maps a class index to its display name.
        max_len: accepted but not used here.
        tokenizer: optional HF tokenizer forwarded to preprocessing and used
            as the display-token fallback.
        multiclass: if true, the top-probability class is explained.
            Otherwise the single sigmoid output is rounded to 0/1.

    The predicted class is also used as the true and attributed class of
    the rendered record.
    """
    if isinstance(text, list):
        text = " ".join(text)

    d = data_utils.process_data_for_transformers(text, bpetokenizer, tokenizer,
                                                 0)
    d = {
        "ids": torch.tensor([d['ids']], dtype=torch.long),
        "mask": torch.tensor([d['mask']], dtype=torch.long),
        "token_type_ids": torch.tensor([d['token_type_ids']], dtype=torch.long)
    }

    try:
        # Wrap the BPE ids with ids 0 and 2 (presumably <s> and </s> --
        # TODO confirm for this vocabulary).
        orig_tokens = [0] + bpetokenizer.encode(text).ids + [2]
        orig_tokens = [bpetokenizer.id_to_token(j) for j in orig_tokens]
    except Exception:
        # Fall back to the HF tokenizer when the BPE tokenizer is unusable
        # (e.g. ``bpetokenizer`` is None). Was a bare ``except:``, which also
        # swallowed KeyboardInterrupt/SystemExit.
        orig_tokens = tokenizer.tokenize(text, add_special_tokens=True)

    model.eval()
    # Both branches used the same forward pass; run it once.
    probs = torch.sigmoid(
        model(d["ids"], d["mask"], d["token_type_ids"])).detach().cpu().numpy()
    if multiclass:
        # probs is (1, num_labels): explain the top label. The previous
        # ``argmax(-1)[0][0]`` indexed a scalar and raised IndexError.
        pred_idx = int(probs[0].argmax())
        preds_proba = probs[0][pred_idx]
        # A multi-output forward needs an explicit attribution target.
        target = pred_idx
    else:
        # Single sigmoid output rounded to 0/1; cast to int so a list-based
        # ``idx2label`` also works (a float index raises TypeError).
        preds_proba = probs[0][0]
        pred_idx = int(np.round(preds_proba))
        target = None
    predicted_class = idx2label[pred_idx]

    lig = LayerIntegratedGradients(model, model.base_model.embeddings)

    # Baseline keeps ids 0 and 2 at the ends and fills the interior with id 1
    # (presumably the padding id -- TODO confirm).
    reference_indices = [0] + [1] * (d["ids"].shape[1] - 2) + [2]
    reference_indices = torch.tensor([reference_indices], dtype=torch.long)

    attributions_ig, delta = lig.attribute(
        inputs=d["ids"],
        baselines=reference_indices,
        additional_forward_args=(d["mask"], d["token_type_ids"]),
        target=target,
        return_convergence_delta=True)

    # Collapse the embedding axis and scale to unit L2 norm.
    attributions = attributions_ig.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().cpu().numpy()

    visualization.visualize_text([
        visualization.VisualizationDataRecord(word_attributions=attributions,
                                              pred_prob=preds_proba,
                                              pred_class=predicted_class,
                                              true_class=predicted_class,
                                              attr_class=predicted_class,
                                              attr_score=attributions.sum(),
                                              raw_input=orig_tokens,
                                              convergence_score=delta)
    ])
示例#14
0
def interpret_main(text, label, model_path=None):
    """Explain a fine-tuned BERT clause classifier in a Streamlit page.

    Attributions of the class-0 softmax probability (the class the UI text
    calls "unacceptable") are computed with LayerIntegratedGradients over the
    BERT embeddings. The page then shows the elapsed time, a bar chart of the
    top attributed tokens and a per-token attribution table.

    Args:
        text: clause text to explain.
        label: label stored in the (currently unrendered) visualization record.
        model_path: checkpoint passed to ``load_model``. Defaults to the
            original hard-coded path for backward compatibility.
    """
    import numpy as np

    if model_path is None:
        model_path = ('/Users/andrewmendez1/Documents/ai-ml-challenge-2020/data/'
                      'Finetune BERT oversampling 8_16_2020/Model_1_4_0/model.pt')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = load_model(model_path, device)

    def predict(inputs):
        # First element of the encoder output for a batch of token ids.
        return model.encoder(inputs)[0]

    def custom_forward(inputs):
        # Softmax probability of class 0: the quantity being attributed.
        preds = predict(inputs)
        return torch.softmax(preds, dim=1)[:, 0]

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    ref_token_id = tokenizer.pad_token_id  # token used for generating the reference input
    sep_token_id = tokenizer.sep_token_id  # separator token, also added at the end of the text
    cls_token_id = tokenizer.cls_token_id  # token prepended to the sequence

    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
        text, tokenizer, device, ref_token_id, sep_token_id, cls_token_id)
    # NOTE(review): the token-type, position-id and attention-mask values
    # below are built but not passed to the attribution call.
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(
        input_ids, device, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(
        input_ids, device)
    attention_mask = construct_attention_mask(input_ids)

    lig = LayerIntegratedGradients(custom_forward,
                                   model.encoder.bert.embeddings)
    t0 = time()
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        return_convergence_delta=True,
        n_steps=300)
    st.write("Time to complete interpretation: {} seconds".format(time() - t0))
    attributions_sum = summarize_attributions(attributions)

    all_tokens = tokenizer.convert_ids_to_tokens(
        input_ids[0].detach().tolist())
    top_k = 7
    top_tokens, values, _indices = get_topk_attributed_tokens(attributions_sum,
                                                              all_tokens,
                                                              k=top_k)
    st.subheader("Top Tokens that the Model decided Unacceptability")
    fig = plt.figure(figsize=(12, 6))
    x_pos = np.arange(len(values))
    # .cpu() so this also works when the model runs on CUDA.
    plt.bar(x_pos, values.detach().cpu().numpy(), align='center')
    plt.xticks(x_pos, top_tokens, wrap=True)
    plt.xlabel("Tokens")
    # The title used to say "Top 5" while 7 tokens were plotted.
    plt.title(
        f"Top {top_k} Tokens that made the model classify clause as unacceptable")
    # Pass the figure explicitly instead of relying on pyplot's global
    # figure, then close it so figures don't accumulate across reruns.
    st.pyplot(fig)
    plt.close(fig)

    st.subheader(
        "Detailed Table showing Attribution Score to each word in clause")
    st.write(" ")
    st.write(
        "Positive Attributions mean that the words/tokens were \"positively\" attributed to the models's prediction."
    )
    st.write(
        "Negative Attributions mean that the words/tokens were \"negatively\" attributed to the models's prediction."
    )

    df = pd.DataFrame({
        'Words': all_tokens,
        'Attributions': attributions_sum.detach().cpu().numpy()
    })
    st.table(df)
    score = predict(input_ids)
    # NOTE(review): this record is built but never rendered or returned, and
    # its attr_class field receives the raw text.
    score_vis = viz.VisualizationDataRecord(
        attributions_sum,
        torch.softmax(score, dim=1)[0][0],
        torch.argmax(torch.softmax(score, dim=1)[0]), label, text,
        attributions_sum.sum(), all_tokens, delta)
    print('\033[1m', 'Visualization For Score', '\033[0m')


# text= "this license shall be effective until company in its sole and absolute at any time and for any or no disable the or suspend or terminate this license and the rights afforded to you with or without prior notice or other action by upon the termination of this you shall cease all use of the app and uninstall the company will not be liable to you or any third party for or damages of any sort as a result of terminating this license in accordance with its and termination of this license will be without prejudice to any other right or remedy company may now or in the these obligations survive termination of this"
# # label=1
# label = "?"
# main(text,label)
示例#15
0
def add_attributions_to_visualizer(attributions,
                                   text,
                                   token_ids,
                                   pred,
                                   pred_ind,
                                   label,
                                   delta,
                                   vis_data_records,
                                   top_k=10):
    """Report and plot the most influential graph nodes for one document.

    Relies on the module-level ``docs``, ``doc_size``, ``id_word_map`` and
    ``plot_with_networkx``:
    - node indices below ``doc_size`` are documents;
    - indices from ``doc_size`` upward are words named
      ``id_word_map[index - doc_size]``;
    - the document's own tokens sit at ``token_id + len(docs)``.

    The function:
    - prints, as JSON, the document's token attributions and the other words
      and documents among the ``top_k`` highest and lowest scoring nodes;
    - appends a visualization record for the document's tokens to
      ``vis_data_records``;
    - plots the selected nodes with their scores.
    """
    summed = attributions.sum(dim=2).squeeze(0)
    scores = (summed / torch.norm(summed)).cpu().detach().numpy()

    own_nodes = [tok_id + len(docs) for tok_id in token_ids]
    own_scores = scores[own_nodes]

    highest = np.argsort(-scores)[:top_k]
    lowest = np.argsort(scores)[:top_k]

    def other_word_nodes(candidates):
        # Word nodes among *candidates* that are not this document's tokens.
        return np.array([i for i in candidates
                         if i >= doc_size and i not in own_nodes])

    def document_nodes(candidates):
        # Document nodes among *candidates*.
        return np.array([i for i in candidates if i < doc_size])

    def word_score_map(indices):
        # Word text -> attribution score for the given word nodes.
        return {id_word_map[i - doc_size]: float(scores[i])
                for i in indices if i not in own_nodes}

    def doc_score_map(indices):
        # Document -> attribution score for the given document nodes.
        return {docs[i]: float(scores[i]) for i in indices}

    words_high = word_score_map(other_word_nodes(highest))
    words_low = word_score_map(other_word_nodes(lowest))
    docs_high = doc_score_map(document_nodes(highest))
    docs_low = doc_score_map(document_nodes(lowest))

    print('all tokens with attributions in this doc')
    token_table = {tok: float(val) for tok, val in zip(text, own_scores)}
    print(json.dumps(token_table, indent=4, ensure_ascii=False))

    print(f'other words -{len(words_high)}- and attribution in top {top_k} nodes')
    for table in (words_high, words_low):
        print(json.dumps(table, indent=4, ensure_ascii=False))

    print(f'docs -{len(docs_high)}- and attribution in top {top_k} nodes')
    for table in (docs_high, docs_low):
        print(json.dumps(table, indent=4, ensure_ascii=False))

    vis_data_records.append(
        visualization.VisualizationDataRecord(own_scores, pred, pred_ind,
                                              label, '1', own_scores.sum(),
                                              text, delta))

    plotted = np.concatenate([highest, lowest, own_nodes])
    node_texts = [
        str(i) if i < doc_size else id_word_map[i - doc_size]
        for i in plotted
    ]
    plot_with_networkx(plotted.tolist(), node_texts, scores[plotted].tolist())