def initialize(self):
     print("initial tokenizer...")
     self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_path)
     self.ref_token_id = self.tokenizer.pad_token_id
     self.sep_token_id = self.tokenizer.sep_token_id
     self.cls_token_id = self.tokenizer.cls_token_id
     print("initial inference model...")
     self.model = BertForSequenceClassification.from_pretrained(
         self.model_path, num_labels=2).cpu().eval()
     self.model.zero_grad()
     print("initial lac model...")
     self.lac = LAC(mode="seg")
     self.lac.load_customization(self.lac_dict_path, sep="\t")
     print("initial interpretable embedding layers ...")
     self.interpretable_embedding1 = configure_interpretable_embedding_layer(
         self.model, 'bert.embeddings.word_embeddings')
     self.interpretable_embedding2 = configure_interpretable_embedding_layer(
         self.model, 'bert.embeddings.token_type_embeddings')
     self.interpretable_embedding3 = configure_interpretable_embedding_layer(
         self.model, 'bert.embeddings.position_embeddings')
     remove_interpretable_embedding_layer(self.model,
                                          self.interpretable_embedding1)
     remove_interpretable_embedding_layer(self.model,
                                          self.interpretable_embedding2)
     remove_interpretable_embedding_layer(self.model,
                                          self.interpretable_embedding3)
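
The method above configures the three BERT embedding sub-layers and removes them again straight away, which only verifies that the layer paths resolve. A minimal sketch of the full Captum lifecycle (configure, convert ids to embeddings, attribute, then always remove); the attribute_word_embeddings wrapper, forward_func, input_ids, ref_input_ids and target below are assumptions for illustration, not part of the original:

from captum.attr import (IntegratedGradients,
                         configure_interpretable_embedding_layer,
                         remove_interpretable_embedding_layer)


def attribute_word_embeddings(model, forward_func, input_ids, ref_input_ids, target):
    # Swap the word-embedding layer for an identity wrapper so attributions
    # can be computed with respect to embedding vectors instead of token ids.
    interp_emb = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.word_embeddings')
    try:
        input_embed = interp_emb.indices_to_embeddings(input_ids)
        ref_embed = interp_emb.indices_to_embeddings(ref_input_ids)
        # forward_func is assumed to feed precomputed embeddings through the
        # model, as the later examples in this listing do.
        ig = IntegratedGradients(forward_func)
        attributions = ig.attribute(inputs=input_embed,
                                    baselines=ref_embed,
                                    target=target)
    finally:
        # Always restore the original embedding layer, even if attribution fails.
        remove_interpretable_embedding_layer(model, interp_emb)
    return attributions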
Example no. 2
 def __post_init__(self) -> None:
     self.special_token_mask = [
         self.tokenizer.unk_token_id, self.tokenizer.sep_token_id,
         self.tokenizer.pad_token_id, self.tokenizer.cls_token_id
     ]
     self.candidate_token_mask, self.candidate_token_map = gen_candidate_mask(
         self.tokenizer)
     self.interpretable_embedding = configure_interpretable_embedding_layer(
         self.model, 'albert.embeddings.word_embeddings')
     self.model_perf_tups = {
         'tweets':
         load_cache(f'{self.config.experiment.dirs.model_cache_dir}/'
                    f'{constants.TWEET_MODEL_PERF_CACHE_NAME}'),
         'nontweets':
         load_cache(f'{self.config.experiment.dirs.model_cache_dir}/'
                    f'{constants.NONTWEET_MODEL_PERF_CACHE_NAME}'),
         'global':
         load_cache(f'{self.config.experiment.dirs.model_cache_dir}/'
                    f'{constants.GLOBAL_MODEL_PERF_CACHE_NAME}')
     }
     self.stmt_embed_dict = load_cache(
         f'{self.config.experiment.dirs.model_cache_dir}/'
         f'{constants.STMT_EMBED_CACHE_NAME}')
     if self.stmt_embed_dict:
         self.max_pred_idx = torch.tensor(
             np.argmax(np.asarray(self.stmt_embed_dict['preds'])))
         self.min_pred_idx = torch.tensor(
             np.argmin(np.asarray(self.stmt_embed_dict['preds'])))
     self.vis_data_records, self.ext_vis_data_records, self.ss_image_paths = [], [], []
Example no. 3
 def configure_embedding_layer(self):
     interp_embs = []
     # configure_interpretable_embedding_layer expects the dotted *name* of the
     # embedding layer, so iterate over named sub-modules and pass that name.
     for name, layer in self.model.named_modules():
         if isinstance(layer, nn.Embedding):
             interp_embs.append(
                 configure_interpretable_embedding_layer(self.model, name))
             break
     return interp_embs
Example no. 4
    def predict_minibatch(self,
                          inputs: List[JsonDict],
                          config=None) -> List[JsonDict]:
        """
        batch size set to 1 for simplicity, to use batch size greater than one, will need
        to use self.tokenizer.batch_encode_plus as in the LIT examples
        :param inputs: JSON of sentence and token to interpret
        :param config:
        :return: prediction output aligned with spec
        """
        mask_token = '[MASK]'
        sentence = inputs[0]['Sentence']
        interpret_token_id = inputs[0]['Token Index to Explain']
        tokens = ['[CLS]'] + self.tokenizer.tokenize(sentence) + ['[SEP]']
        input_ids = [self.tokenizer.convert_tokens_to_ids(tokens)]
        input_mask = [[1] * len(input_ids[0])]
        input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
        input_mask_tensor = torch.tensor(input_mask, dtype=torch.long)
        # Needed for calculating grad based on embeddings
        interpretable_embedding = configure_interpretable_embedding_layer(
            self.model, 'bert.embeddings.word_embeddings')
        input_embeddings = interpretable_embedding.indices_to_embeddings(
            input_ids_tensor)
        model_input = {
            "inputs_embeds": input_embeddings,
            "attention_mask": input_mask_tensor
        }
        model_output = self.model(**model_input)
        logits, embs, unused_attentions = model_output[:3]
        logits_ndarray = logits.detach().cpu().numpy()
        example_preds = np.argmax(logits_ndarray, axis=2)
        confidences = torch.softmax(torch.from_numpy(logits_ndarray),
                                    dim=2).detach().cpu().numpy()
        label_map = {i: label for i, label in enumerate(self.LABELS)}
        predictions = [label_map[pred] for pred in example_preds[0]]
        outputs = {}
        for i, attention_layer in enumerate(unused_attentions):
            outputs[f'layer_{i}/attention'] = attention_layer[0].detach().cpu(
            ).numpy().copy()

        # TODO Currently LIT lime explainer does not support targeting a specific token, until that's fixed,
        #  we explain the first non-O index if there's one, or the first token (after [CLS]).
        if interpret_token_id < 0 or mask_token in sentence:
            scalar_output = np.where(example_preds[0] != 0)[0]
            token_index = scalar_output[0] if len(scalar_output) > 0 else 1
        else:
            # TODO When LIT lime explainer is configurable, we'll set the token_index from the UI
            token_index = interpret_token_id

        outputs['tokens'] = tokens
        outputs['bio_tags'] = predictions
        grad = torch.autograd.grad(torch.unbind(logits[0][token_index]),
                                   embs[0])
        outputs['grads'] = grad[0][0].detach().cpu().numpy()
        outputs['probas'] = confidences[0][token_index]
        outputs['token_ids'] = list(range(0, len(tokens)))

        remove_interpretable_embedding_layer(self.model,
                                             interpretable_embedding)
        yield outputs
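
The docstring above notes that batches larger than one would need self.tokenizer.batch_encode_plus. A minimal sketch of that encoding step, assuming a HuggingFace tokenizer; the sentences list and variable names are illustrative, not from the source:

# Hypothetical batched encoding (variable names are illustrative).
sentences = [example['Sentence'] for example in inputs]
encoded = tokenizer.batch_encode_plus(sentences,
                                      add_special_tokens=True,
                                      padding=True,
                                      return_tensors='pt')
input_ids_tensor = encoded['input_ids']        # shape (batch, max_len)
input_mask_tensor = encoded['attention_mask']  # 1 = real token, 0 = padding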
Example no. 5
def attribute_predict(collate_fn, model_args, dataset, attribution_method,
                      model, target, embedding_name):
    predicted_logits, attributions, token_ids_all, true_target = [], [], [], []

    dl = torch.utils.data.DataLoader(batch_size=model_args.batch_size,
                                     dataset=dataset,
                                     collate_fn=collate_fn)

    if not isinstance(attribution_method, LimeBase):
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_name)

    for batch in tqdm(dl):
        token_ids = batch[0]
        if model_args.model == 'lstm':
            additional_args = (batch[-1],)
        elif model_args.model == 'transformer':
            additional_args = (token_ids != 0,)
        else:
            additional_args = None
        if isinstance(attribution_method, LimeBase):
            inputs = token_ids
            instance_attribution = attribution_method.attribute(token_ids,
                                                                n_perturb_samples=100,
                                                                additional_forward_args=additional_args,
                                                                target=target)
        else:
            inputs = interpretable_embedding.indices_to_embeddings(token_ids)
            instance_attribution = attribution_method.attribute(inputs,
                                                                additional_forward_args=additional_args,
                                                                target=target)
            instance_attribution = summarize_attributions(instance_attribution,
                                                          type='mean').detach().cpu()

        extra = additional_args[0] if additional_args else None
        predicted_logits += model(inputs, extra).detach().cpu().numpy().tolist()
        attributions += instance_attribution
        token_ids_all += token_ids.detach().cpu().numpy().tolist()
        true_target += batch[1].detach().cpu().numpy().tolist()

    if not isinstance(attribution_method, LimeBase):
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    return predicted_logits, attributions, token_ids_all, true_target
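
A hedged usage sketch for attribute_predict: the collate function, dataset and target class below are placeholders rather than names from the source, and the embedding name assumes a BERT-style model.

# Hypothetical call: gradient-based attribution with Integrated Gradients.
ig = IntegratedGradients(model)
logits, attrs, token_ids, labels = attribute_predict(
    collate_fn, model_args, test_dataset, ig, model,
    target=1, embedding_name='bert.embeddings.word_embeddings')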
Example no. 6

# post = {
#             "type": "captum",
#             "input_string": input_text,
#             "model_name": model_name,
#             "recipe_name": recipe_name,
#             "output_string": output_text
#         }

interpretable_embedding = configure_interpretable_embedding_layer(
    clone.model, 'bert.embeddings')
ref_token_id = original_tokenizer.tokenizer.pad_token_id
sep_token_id = original_tokenizer.tokenizer.sep_token_id
cls_token_id = original_tokenizer.tokenizer.cls_token_id


def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions
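
A quick shape check makes the helper's contract concrete; the dummy tensor below is purely illustrative.

# A (1, seq_len, hidden) attribution collapses to one score per token,
# and the resulting vector is scaled to unit L2 norm.
dummy = torch.randn(1, 7, 768)
token_scores = summarize_attributions(dummy)
print(token_scores.shape)  # torch.Size([7])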


def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

Example no. 7
def captum_heatmap_interactive(request):
    if request.method == 'POST':
        STORED_POSTS = request.session.get("TextAttackResult")
        form = CustomData(request.POST)
        if form.is_valid():
            input_text = form.cleaned_data['input_text']
            model_name = form.cleaned_data['model_name']
            recipe_name = form.cleaned_data['recipe_name']
            found = False
            if STORED_POSTS:
                JSON_STORED_POSTS = json.loads(STORED_POSTS)
                for idx, el in enumerate(JSON_STORED_POSTS):
                    if el["type"] == "heatmap" and el[
                            "input_string"] == input_text:
                        tmp = JSON_STORED_POSTS.pop(idx)
                        JSON_STORED_POSTS.insert(0, tmp)
                        found = True
                        break

                if found:
                    request.session["TextAttackResult"] = json.dumps(
                        JSON_STORED_POSTS[:10])
                    return HttpResponseRedirect(reverse('webdemo:index'))

            original_model = transformers.AutoModelForSequenceClassification.from_pretrained(
                "textattack/" + model_name)
            original_tokenizer = textattack.models.tokenizers.AutoTokenizer(
                "textattack/" + model_name)
            model = textattack.models.wrappers.HuggingFaceModelWrapper(
                original_model, original_tokenizer)

            device = torch.device(
                "cuda:2" if torch.cuda.is_available() else "cpu")
            clone = deepcopy(model)
            clone.model.to(device)

            def calculate(input_ids, token_type_ids, attention_mask):
                return clone.model(input_ids, token_type_ids,
                                   attention_mask)[0]

            interpretable_embedding = configure_interpretable_embedding_layer(
                clone.model, 'bert.embeddings')
            ref_token_id = original_tokenizer.tokenizer.pad_token_id
            sep_token_id = original_tokenizer.tokenizer.sep_token_id
            cls_token_id = original_tokenizer.tokenizer.cls_token_id

            def summarize_attributions(attributions):
                attributions = attributions.sum(dim=-1).squeeze(0)
                attributions = attributions / torch.norm(attributions)
                return attributions

            def construct_attention_mask(input_ids):
                return torch.ones_like(input_ids)

            def construct_input_ref_pos_id_pair(input_ids):
                seq_length = input_ids.size(1)
                position_ids = torch.arange(seq_length,
                                            dtype=torch.long,
                                            device=device)
                ref_position_ids = torch.zeros(seq_length,
                                               dtype=torch.long,
                                               device=device)

                position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
                ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(
                    input_ids)
                return position_ids, ref_position_ids

            def squad_pos_forward_func(inputs,
                                       token_type_ids=None,
                                       attention_mask=None):
                pred = calculate(inputs, token_type_ids, attention_mask)
                return pred.max(1).values

            def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
                seq_len = input_ids.size(1)
                token_type_ids = torch.tensor(
                    [[0 if i <= sep_ind else 1 for i in range(seq_len)]],
                    device=device)
                ref_token_type_ids = torch.zeros_like(token_type_ids,
                                                      device=device)  # * -1
                return token_type_ids, ref_token_type_ids

            input_text_ids = original_tokenizer.tokenizer.encode(
                input_text, add_special_tokens=False)
            input_ids = [cls_token_id] + input_text_ids + [sep_token_id]
            input_ids = torch.tensor([input_ids], device=device)

            position_ids, ref_position_ids = construct_input_ref_pos_id_pair(
                input_ids)
            ref_input_ids = [
                cls_token_id
            ] + [ref_token_id] * len(input_text_ids) + [sep_token_id]
            ref_input_ids = torch.tensor([ref_input_ids], device=device)

            token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(
                input_ids, len(input_text_ids))
            attention_mask = torch.ones_like(input_ids)

            input_embeddings = interpretable_embedding.indices_to_embeddings(
                input_ids,
                token_type_ids=token_type_ids,
                position_ids=position_ids)
            ref_input_embeddings = interpretable_embedding.indices_to_embeddings(
                ref_input_ids,
                token_type_ids=ref_token_type_ids,
                position_ids=ref_position_ids)

            layer_attrs_start = []

            for i in range(len(clone.model.bert.encoder.layer)):
                lc = LayerConductance(squad_pos_forward_func,
                                      clone.model.bert.encoder.layer[i])
                layer_attributions_start = lc.attribute(
                    inputs=input_embeddings,
                    baselines=ref_input_embeddings,
                    additional_forward_args=(token_type_ids,
                                             attention_mask))[0]

                layer_attrs_start.append(
                    summarize_attributions(
                        layer_attributions_start).cpu().detach().tolist())

            all_tokens = original_tokenizer.tokenizer.convert_ids_to_tokens(
                input_ids[0])

            fig, ax = plt.subplots(figsize=(15, 5))
            xticklabels = all_tokens
            yticklabels = list(range(1, 13))
            ax = sns.heatmap(np.array(layer_attrs_start),
                             xticklabels=xticklabels,
                             yticklabels=yticklabels,
                             linewidth=0.2)
            plt.xlabel('Tokens')
            plt.ylabel('Layers')

            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            bufferString = base64.b64encode(buf.read())
            imageUri = urllib.parse.quote(bufferString)

            post = {
                "type": "heatmap",
                "input_string": input_text,
                "model_name": model_name,
                "recipe_name": recipe_name,
                "image": imageUri,
            }

            if STORED_POSTS:
                JSON_STORED_POSTS = json.loads(STORED_POSTS)
                JSON_STORED_POSTS.insert(0, post)
                request.session["TextAttackResult"] = json.dumps(
                    JSON_STORED_POSTS[:10])
            else:
                request.session["TextAttackResult"] = json.dumps([post])

            return HttpResponseRedirect(reverse('webdemo:index'))

        else:
            return HttpResponseNotFound('Failed')


    return HttpResponseNotFound('<h1>Not Found</h1>')
Example no. 8
def predict(input_abstract: str, input_title: str, input_keywords: str,
            model: nn.Module, vectors: Vectors, output_vectors: Dictionary,
            device: torch.device):
    """
    :param input_abstract:
    :param input_title:
    :param input_keywords:
    :param model:
    :param vectors:
    :param output_vectors:
    :param device:
    :return:
    """
    # Prepare for visualization
    vis_data_records_ig = []

    # Interpret sentence
    try:
        interpretable_embedding_abstracts = configure_interpretable_embedding_layer(
            model, 'embedding_abstracts')
        interpretable_embedding_titles = configure_interpretable_embedding_layer(
            model, 'embedding_titles')
        interpretable_embedding_keywords = configure_interpretable_embedding_layer(
            model, 'embedding_keywords')
        ig = IntegratedGradients(model)
    except Exception as err:
        print(f"Failed to configure interpretable embedding layers: {err}")
        exit(1)

    print(f"Created Interpretable Layers: {datetime.datetime.now()}")

    interpret_sentence(
        model=model,
        input_abstract=input_abstract,
        input_title=input_title,
        input_keywords=input_keywords,
        vectors=vectors,
        interpretable_embedding_abstracts=interpretable_embedding_abstracts,
        interpretable_embedding_titles=interpretable_embedding_titles,
        interpretable_embedding_keywords=interpretable_embedding_keywords,
        ig=ig,
        vis_data_records_ig=vis_data_records_ig,
        output_vectors=output_vectors,
        device=device)

    print(f"Interpreted: {datetime.datetime.now()}")

    # Show interpretations
    #print(build_html(vis_data_records_ig))
    json_data = build_json(vis_data_records_ig)

    print(f"Built JSON: {datetime.datetime.now()}")

    remove_interpretable_embedding_layer(model,
                                         interpretable_embedding_abstracts)
    remove_interpretable_embedding_layer(model, interpretable_embedding_titles)
    remove_interpretable_embedding_layer(model,
                                         interpretable_embedding_keywords)

    print(f"Removed Layers: {datetime.datetime.now()}")

    return json_data
Example no. 9
    batch_labels_ndarray = batch_labels.detach().cpu().numpy()
    if preds is None:
        preds = logits_ndarray
        out_label_ids = inputs["labels"].detach().cpu().numpy()
    else:
        preds = np.append(preds, logits_ndarray, axis=0)
        out_label_ids = np.append(out_label_ids,
                                  inputs["labels"].detach().cpu().numpy(),
                                  axis=0)

    example_preds = np.argmax(logits_ndarray, axis=2)

    if np.any(batch_labels_ndarray[0] != 0):
        # interpretable_embedding = configure_interpretable_embedding_layer(deeplift_model,
        #                                                                   'model.bert.embeddings.word_embeddings')
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, 'bert.embeddings.word_embeddings')
        input_embeddings = interpretable_embedding.indices_to_embeddings(
            input_ids)

        for token_index in np.where(batch_labels_ndarray[0] != 0)[0]:
            if out_label_ids[example_index][token_index] != pad_token_label_id:
                label_id = example_preds[0][token_index]
                true_label = batch_labels[0][token_index].item()
                target = (token_index, label_id)
                logger.info(
                    f'Calculating attribution for label {label_id} at index {token_index}'
                )
                attribution_start = time.time()

                # attributions, delta = explainer.attribute(input_embeddings, target=target,
                #                                           additional_forward_args=batch_labels,
def generate(sample, lac):
    model.zero_grad()

    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
        sample.text, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(
        input_ids, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    tokens = tokenizer.convert_ids_to_tokens(indices)

    input_ids = input_ids.to(device)
    ref_input_ids = ref_input_ids.to(device)
    token_type_ids = token_type_ids.to(device)
    ref_token_type_ids = ref_token_type_ids.to(device)
    position_ids = position_ids.to(device)
    ref_position_ids = ref_position_ids.to(device)
    attention_mask = attention_mask.to(device)

    logits = forward_func(input_ids,
                          token_type_ids=token_type_ids,
                          position_ids=position_ids,
                          attention_mask=attention_mask)
    pred = int(logits.argmax(1))
    score = float(logits[0][pred])

    interpretable_embedding1 = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.word_embeddings')
    interpretable_embedding2 = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.token_type_embeddings')
    interpretable_embedding3 = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.position_embeddings')
    (input_embed, ref_input_embed), (token_type_ids_embed, ref_token_type_ids_embed), \
        (position_ids_embed, ref_position_ids_embed) = construct_bert_sub_embedding(
            input_ids, ref_input_ids, token_type_ids=token_type_ids,
            ref_token_type_ids=ref_token_type_ids, position_ids=position_ids,
            ref_position_ids=ref_position_ids)

    lig = IntegratedGradients(forward_func)
    attributions, delta = lig.attribute(
        inputs=(input_embed, token_type_ids_embed, position_ids_embed),
        baselines=(ref_input_embed, ref_token_type_ids_embed,
                   ref_position_ids_embed),
        target=sample.label,
        additional_forward_args=(attention_mask),
        return_convergence_delta=True)

    input_ids = input_ids.cpu()
    ref_input_ids = ref_input_ids.cpu()
    token_type_ids = token_type_ids.cpu()
    ref_token_type_ids = ref_token_type_ids.cpu()
    position_ids = position_ids.cpu()
    ref_position_ids = ref_position_ids.cpu()
    attention_mask = attention_mask.cpu()
    input_embed = input_embed.cpu()
    ref_input_embed = ref_input_embed.cpu()
    token_type_ids_embed = token_type_ids_embed.cpu()
    ref_token_type_ids_embed = ref_token_type_ids_embed.cpu()
    position_ids_embed = position_ids_embed.cpu()
    ref_position_ids_embed = ref_position_ids_embed.cpu()

    torch.cuda.empty_cache()

    _, attribution_words = word_level_spline(
        tokens,
        summarize_attributions(attributions[0]).cpu().detach().numpy(), lac)
    _, attribution_position = word_level_spline(
        tokens,
        summarize_attributions(attributions[2]).cpu().detach().numpy(), lac)
    words = [each[0] for each in attribution_words]
    attribution_merge = [
        attribution_words[i][1] + attribution_position[i][1]
        for i in range(len(attribution_words))
    ]
    remove_interpretable_embedding_layer(model, interpretable_embedding1)
    remove_interpretable_embedding_layer(model, interpretable_embedding2)
    remove_interpretable_embedding_layer(model, interpretable_embedding3)

    return Result(words=words, label=pred, attribution=attribution_merge)
model_path = "checkpoints/roberta24"

model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)
model.to(device)
model.eval()
model.zero_grad()

tokenizer_path = "hfl/chinese-roberta-wwm-ext"
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

print("prepare interpretable embedding...")
interpretable_embedding1 = configure_interpretable_embedding_layer(
    model, 'bert.embeddings.word_embeddings')
interpretable_embedding2 = configure_interpretable_embedding_layer(
    model, 'bert.embeddings.token_type_embeddings')
interpretable_embedding3 = configure_interpretable_embedding_layer(
    model, 'bert.embeddings.position_embeddings')
remove_interpretable_embedding_layer(model, interpretable_embedding1)
remove_interpretable_embedding_layer(model, interpretable_embedding2)
remove_interpretable_embedding_layer(model, interpretable_embedding3)


def predict(inputs,
            token_type_ids=None,
            position_ids=None,
            attention_mask=None):
    return model(
        inputs,
Example no. 12
def generate_saliency(model_path, saliency_path, saliency, aggregation):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer,
                           model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer,
                          model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])

    model.train()

    pad_to_max = False
    if saliency == 'deeplift':
        ablator = DeepLift(model)
    elif saliency == 'guided':
        ablator = GuidedBackprop(model)
    elif saliency == 'sal':
        ablator = Saliency(model)
    elif saliency == 'inputx':
        ablator = InputXGradient(model)
    elif saliency == 'occlusion':
        ablator = Occlusion(model)

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)

    return_attention_masks = args.model == 'trans'

    collate_fn = partial(coll_call,
                         tokenizer=tokenizer,
                         device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir,
                       mode=args.split,
                       dataset=args.dataset)
    batch_size = args.batch_size if args.batch_size is not None else \
        model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size,
                         dataset=test,
                         shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(lambda: [])
        for batch in tqdm(test_dl, desc='Running test prediction... '):
            if args.model == 'trans':
                logits = model(batch[0],
                               attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits

        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    if saliency != 'occlusion':
        embedding_layer_name = ('model.bert.embeddings'
                                if args.model == 'trans' else 'embedding')
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_layer_name)

    class_attr_list = defaultdict(lambda: [])
    token_ids = []
    saliency_flops = []

    for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
        if args.model == 'cnn':
            additional = None
        elif args.model == 'trans':
            additional = (batch[1], batch[2])
        else:
            additional = batch[-1]

        token_ids += batch[0].detach().cpu().numpy().tolist()
        if saliency != 'occlusion':
            input_embeddings = interpretable_embedding.indices_to_embeddings(
                batch[0])

        if not args.no_time:
            high.start_counters([
                events.PAPI_FP_OPS,
            ])
        for cls_ in range(checkpoint['args']['labels']):
            if saliency == 'occlusion':
                attributions = ablator.attribute(
                    batch[0],
                    sliding_window_shapes=(args.sw, ),
                    target=cls_,
                    additional_forward_args=additional)
            else:
                attributions = ablator.attribute(
                    input_embeddings,
                    target=cls_,
                    additional_forward_args=additional)

            attributions = summarize_attributions(
                attributions, type=aggregation, model=model,
                tokens=batch[0]).detach().cpu().numpy().tolist()
            class_attr_list[cls_] += [[_li for _li in _l]
                                      for _l in attributions]

        if not args.no_time:
            saliency_flops.append(
                sum(high.stop_counters()) / batch[0].shape[0])

    if saliency != 'occlusion':
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    # SERIALIZE
    print('Serializing...', flush=True)
    with open(saliency_path, 'w') as out:
        for instance_i, _ in enumerate(test):
            saliencies = []
            for token_i, token_id in enumerate(token_ids[instance_i]):
                token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(checkpoint['args']['labels']):
                    token_sal[int(
                        cls_)] = class_attr_list[cls_][instance_i][token_i]
                saliencies.append(token_sal)

            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()

    return saliency_flops
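
generate_saliency runs test-set prediction, per-class saliency computation and serialization in one pass. A hypothetical driver loop; the checkpoint path, output directory and method list below are assumptions, not from the source:

# Hypothetical: one saliency file per method for a single checkpoint.
for method in ['sal', 'inputx', 'guided', 'deeplift', 'occlusion']:
    flops = generate_saliency('snapshots/model_cp.pt',
                              f'saliency_out/{method}_mean.jsonl',
                              saliency=method, aggregation='mean')
    print(method, 'mean FLOPs per instance:',
          np.mean(flops) if flops else 'n/a')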
    def predict(self, text):
        self.model.zero_grad()
        input_ids, ref_input_ids, sep_id = self.construct_input_ref_pair(
            text, self.ref_token_id, self.sep_token_id, self.cls_token_id)
        token_type_ids, ref_token_type_ids = self.construct_input_ref_token_type_pair(
            input_ids, sep_id)
        position_ids, ref_position_ids = self.construct_input_ref_pos_id_pair(
            input_ids)
        attention_mask = self.construct_attention_mask(input_ids)
        indices = input_ids[0].detach().tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(indices)
        logits = self.forward_func(input_ids,
                                   token_type_ids=token_type_ids,
                                   position_ids=position_ids,
                                   attention_mask=attention_mask)
        pred = int(logits.argmax(1))
        logits = logits.tolist()[0]

        self.interpretable_embedding1 = configure_interpretable_embedding_layer(
            self.model, 'bert.embeddings.word_embeddings')
        self.interpretable_embedding2 = configure_interpretable_embedding_layer(
            self.model, 'bert.embeddings.token_type_embeddings')
        self.interpretable_embedding3 = configure_interpretable_embedding_layer(
            self.model, 'bert.embeddings.position_embeddings')
        (input_embed, ref_input_embed), (token_type_ids_embed, ref_token_type_ids_embed), \
            (position_ids_embed, ref_position_ids_embed) = self.construct_bert_sub_embedding(
                input_ids, ref_input_ids, token_type_ids=token_type_ids,
                ref_token_type_ids=ref_token_type_ids, position_ids=position_ids,
                ref_position_ids=ref_position_ids)

        lig = IntegratedGradients(self.forward_func)

        attr = []
        for label in range(len(logits)):
            attributions, delta = lig.attribute(
                inputs=(input_embed, token_type_ids_embed, position_ids_embed),
                baselines=(ref_input_embed, ref_token_type_ids_embed,
                           ref_position_ids_embed),
                target=label,
                additional_forward_args=(attention_mask),
                return_convergence_delta=True)
            _, attribution_words = self.word_level_spline(
                tokens,
                self.summarize_attributions(
                    attributions[0]).cpu().detach().numpy(), self.lac)
            _, attribution_position = self.word_level_spline(
                tokens,
                self.summarize_attributions(
                    attributions[2]).cpu().detach().numpy(), self.lac)
            if len(attribution_words) != 0:
                words = [each[0] for each in attribution_words]
                attribution_merge = [
                    attribution_words[i][1] + attribution_position[i][1]
                    for i in range(len(attribution_words))
                ]
                range_limit = np.max(np.abs(attribution_merge))
                attribution_merge /= range_limit
                attr.append(list(attribution_merge))
            else:
                words = []
                attr.append([])

        remove_interpretable_embedding_layer(self.model,
                                             self.interpretable_embedding1)
        remove_interpretable_embedding_layer(self.model,
                                             self.interpretable_embedding2)
        remove_interpretable_embedding_layer(self.model,
                                             self.interpretable_embedding3)

        return words, logits, pred, attr
Example no. 14
                                 position_ids=position_ids,
                                 attention_mask=attention_mask)
    pred = pred[position]
    return pred.max(1).values


# Optional[int]
ref_token_id = tokenizer.pad_token_id

# Optional[int]
sep_token_id = tokenizer.sep_token_id

# Optional[int]
cls_token_id = tokenizer.cls_token_id

interpretable_embedding: InterpretableEmbeddingBase = \
    configure_interpretable_embedding_layer(model, 'bert.embeddings')
interpretable_embedding1: InterpretableEmbeddingBase = \
    configure_interpretable_embedding_layer(model, 'bert.embeddings.word_embeddings')
interpretable_embedding2: InterpretableEmbeddingBase = \
    configure_interpretable_embedding_layer(model, 'bert.embeddings.token_type_embeddings')
interpretable_embedding3: InterpretableEmbeddingBase = \
    configure_interpretable_embedding_layer(model, 'bert.embeddings.position_embeddings')




def construct_input_ref_pair(question: str, text: str, ref_token_id: int | str,
                             sep_token_id: int | str, cls_token_id: int | str
                             ) -> tuple[torch.Tensor, torch.Tensor, int]:
    question_ids: list = tokenizer.encode(question, add_special_tokens=False)
    text_ids: list = tokenizer.encode(text, add_special_tokens=False)