def initialize(self):
    print("initializing tokenizer...")
    self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_path)
    self.ref_token_id = self.tokenizer.pad_token_id
    self.sep_token_id = self.tokenizer.sep_token_id
    self.cls_token_id = self.tokenizer.cls_token_id

    print("initializing inference model...")
    self.model = BertForSequenceClassification.from_pretrained(
        self.model_path, num_labels=2).cpu().eval()
    self.model.zero_grad()

    print("initializing lac model...")
    self.lac = LAC(mode="seg")
    self.lac.load_customization(self.lac_dict_path, sep="\t")

    print("initializing interpretable embedding layers ...")
    self.interpretable_embedding1 = configure_interpretable_embedding_layer(
        self.model, 'bert.embeddings.word_embeddings')
    self.interpretable_embedding2 = configure_interpretable_embedding_layer(
        self.model, 'bert.embeddings.token_type_embeddings')
    self.interpretable_embedding3 = configure_interpretable_embedding_layer(
        self.model, 'bert.embeddings.position_embeddings')
    remove_interpretable_embedding_layer(self.model, self.interpretable_embedding1)
    remove_interpretable_embedding_layer(self.model, self.interpretable_embedding2)
    remove_interpretable_embedding_layer(self.model, self.interpretable_embedding3)
def __post_init__(self) -> None:
    self.special_token_mask = [
        self.tokenizer.unk_token_id, self.tokenizer.sep_token_id,
        self.tokenizer.pad_token_id, self.tokenizer.cls_token_id
    ]
    self.candidate_token_mask, self.candidate_token_map = gen_candidate_mask(
        self.tokenizer)
    self.interpretable_embedding = configure_interpretable_embedding_layer(
        self.model, 'albert.embeddings.word_embeddings')
    self.model_perf_tups = {
        'tweets': load_cache(
            f'{self.config.experiment.dirs.model_cache_dir}/'
            f'{constants.TWEET_MODEL_PERF_CACHE_NAME}'),
        'nontweets': load_cache(
            f'{self.config.experiment.dirs.model_cache_dir}/'
            f'{constants.NONTWEET_MODEL_PERF_CACHE_NAME}'),
        'global': load_cache(
            f'{self.config.experiment.dirs.model_cache_dir}/'
            f'{constants.GLOBAL_MODEL_PERF_CACHE_NAME}')
    }
    self.stmt_embed_dict = load_cache(
        f'{self.config.experiment.dirs.model_cache_dir}/'
        f'{constants.STMT_EMBED_CACHE_NAME}')
    if self.stmt_embed_dict:
        self.max_pred_idx = torch.tensor(
            np.argmax(np.asarray(self.stmt_embed_dict['preds'])))
        self.min_pred_idx = torch.tensor(
            np.argmin(np.asarray(self.stmt_embed_dict['preds'])))
    self.vis_data_records, self.ext_vis_data_records, self.ss_image_paths = [], [], []
def configure_embedding_layer(self):
    interp_embs = []
    # configure_interpretable_embedding_layer expects the dotted name of the
    # layer, so iterate named modules and wrap the first embedding layer found.
    for name, module in self.model.named_modules():
        if isinstance(module, nn.Embedding):
            interp_embs.append(
                configure_interpretable_embedding_layer(self.model, name))
            break
    return interp_embs
def predict_minibatch(self, inputs: List[JsonDict],
                      config=None) -> List[JsonDict]:
    """
    Batch size is set to 1 for simplicity; to use a batch size greater than
    one, use self.tokenizer.batch_encode_plus as in the LIT examples.

    :param inputs: JSON of sentence and token to interpret
    :param config:
    :return: prediction output aligned with spec
    """
    mask_token = '[MASK]'
    sentence = inputs[0]['Sentence']
    interpret_token_id = inputs[0]['Token Index to Explain']
    tokens = ['[CLS]'] + self.tokenizer.tokenize(sentence) + ['[SEP]']
    input_ids = [self.tokenizer.convert_tokens_to_ids(tokens)]
    input_mask = [[1] * len(input_ids[0])]
    input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
    input_mask_tensor = torch.tensor(input_mask, dtype=torch.long)

    # Needed for calculating grad based on embeddings
    interpretable_embedding = configure_interpretable_embedding_layer(
        self.model, 'bert.embeddings.word_embeddings')
    input_embeddings = interpretable_embedding.indices_to_embeddings(
        input_ids_tensor)
    model_input = {
        "inputs_embeds": input_embeddings,
        "attention_mask": input_mask_tensor
    }
    model_output = self.model(**model_input)
    logits, embs, unused_attentions = model_output[:3]
    logits_ndarray = logits.detach().cpu().numpy()
    example_preds = np.argmax(logits_ndarray, axis=2)
    confidences = torch.softmax(torch.from_numpy(logits_ndarray),
                                dim=2).detach().cpu().numpy()
    label_map = {i: label for i, label in enumerate(self.LABELS)}
    predictions = [label_map[pred] for pred in example_preds[0]]

    outputs = {}
    for i, attention_layer in enumerate(unused_attentions):
        outputs[f'layer_{i}/attention'] = attention_layer[0].detach().cpu(
        ).numpy().copy()

    # TODO Currently the LIT lime explainer does not support targeting a specific
    # token; until that's fixed, we explain the first non-O index if there is one,
    # or the first token (after [CLS]).
    if interpret_token_id < 0 or mask_token in sentence:
        scalar_output = np.where(example_preds[0] != 0)[0]
        token_index = scalar_output[0] if len(scalar_output) > 0 else 1
    else:
        # TODO When the LIT lime explainer is configurable, set token_index from the UI
        token_index = interpret_token_id

    outputs['tokens'] = tokens
    outputs['bio_tags'] = predictions
    grad = torch.autograd.grad(torch.unbind(logits[0][token_index]), embs[0])
    outputs['grads'] = grad[0][0].detach().cpu().numpy()
    outputs['probas'] = confidences[0][token_index]
    outputs['token_ids'] = list(range(0, len(tokens)))
    remove_interpretable_embedding_layer(self.model, interpretable_embedding)
    yield outputs
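A minimal usage sketch for the LIT-style method above. The wrapper class name and its constructor are assumptions (they are not part of the original snippet); only the input keys ('Sentence', 'Token Index to Explain') and the output keys come from predict_minibatch itself.

# Hypothetical usage sketch: NerInterpreterModel is an assumed wrapper class
# exposing predict_minibatch; the sentence and token index are example values.
example_input = {
    "Sentence": "Alice moved to Berlin in 2019 .",
    "Token Index to Explain": 4,  # explain the prediction for "Berlin"
}
lit_model = NerInterpreterModel()  # assumed constructor
for output in lit_model.predict_minibatch([example_input]):
    print(output["tokens"])
    print(output["bio_tags"])
    print(output["probas"])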
def attribute_predict(collate_fn, model_args, dataset, attribution_method,
                      model, target, embedding_name):
    predicted_logits, attributions, token_ids_all, true_target = [], [], [], []
    dl = torch.utils.data.DataLoader(batch_size=model_args.batch_size,
                                     dataset=dataset,
                                     collate_fn=collate_fn)
    if not isinstance(attribution_method, LimeBase):
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_name)
    for batch in tqdm(dl):
        token_ids = batch[0]
        if model_args.model == 'lstm':
            additional_args = (batch[-1],)
        elif model_args.model == 'transformer':
            additional_args = (token_ids != 0,)
        else:
            additional_args = None
        if isinstance(attribution_method, LimeBase):
            inputs = token_ids
            instance_attribution = attribution_method.attribute(
                token_ids,
                n_perturb_samples=100,
                additional_forward_args=additional_args,
                target=target)
        else:
            inputs = interpretable_embedding.indices_to_embeddings(token_ids)
            instance_attribution = attribution_method.attribute(
                inputs,
                additional_forward_args=additional_args,
                target=target)
        instance_attribution = summarize_attributions(
            instance_attribution, type='mean').detach().cpu()
        predicted_logits += model(
            inputs,
            additional_args[0] if additional_args else None
        ).detach().cpu().numpy().tolist()
        attributions += instance_attribution
        token_ids_all += token_ids.detach().cpu().numpy().tolist()
        true_target += batch[1].detach().cpu().numpy().tolist()
    if not isinstance(attribution_method, LimeBase):
        remove_interpretable_embedding_layer(model, interpretable_embedding)
    return predicted_logits, attributions, token_ids_all, true_target
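A short hedged usage sketch for attribute_predict. The model, model_args, collate_fn, and dataset objects are assumptions from the surrounding training code; the sketch only shows how the embedding_name argument selects the layer wrapped by configure_interpretable_embedding_layer inside the function.

# Hedged usage sketch: model, model_args, collate_fn, and test_dataset are
# assumed to exist in this script; 'embedding' matches the non-transformer
# embedding layer name used elsewhere in these snippets.
from captum.attr import IntegratedGradients

ig = IntegratedGradients(model)
logits, attrs, token_ids, targets = attribute_predict(
    collate_fn=collate_fn,
    model_args=model_args,
    dataset=test_dataset,
    attribution_method=ig,
    model=model,
    target=1,
    embedding_name='embedding')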
#     torch.argmax(x_pert, dim=1).item(),
#     goal_func_result.output,
#     2,
#     atts_pert.sum().detach(),
#     all_tokens_pert[:45],
#     delta_pert)
# post = {
#     "type": "captum",
#     "input_string": input_text,
#     "model_name": model_name,
#     "recipe_name": recipe_name,
#     "output_string": output_text
# }

interpretable_embedding = configure_interpretable_embedding_layer(
    clone.model, 'bert.embeddings')
ref_token_id = original_tokenizer.tokenizer.pad_token_id
sep_token_id = original_tokenizer.tokenizer.sep_token_id
cls_token_id = original_tokenizer.tokenizer.cls_token_id

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)
def captum_heatmap_interactive(request):
    if request.method == 'POST':
        STORED_POSTS = request.session.get("TextAttackResult")
        form = CustomData(request.POST)
        if form.is_valid():
            input_text, model_name, recipe_name = (
                form.cleaned_data['input_text'],
                form.cleaned_data['model_name'],
                form.cleaned_data['recipe_name'])
            found = False
            if STORED_POSTS:
                JSON_STORED_POSTS = json.loads(STORED_POSTS)
                for idx, el in enumerate(JSON_STORED_POSTS):
                    if el["type"] == "heatmap" and el["input_string"] == input_text:
                        tmp = JSON_STORED_POSTS.pop(idx)
                        JSON_STORED_POSTS.insert(0, tmp)
                        found = True
                        break
            if found:
                request.session["TextAttackResult"] = json.dumps(
                    JSON_STORED_POSTS[:10])
                return HttpResponseRedirect(reverse('webdemo:index'))

            original_model = transformers.AutoModelForSequenceClassification.from_pretrained(
                "textattack/" + model_name)
            original_tokenizer = textattack.models.tokenizers.AutoTokenizer(
                "textattack/" + model_name)
            model = textattack.models.wrappers.HuggingFaceModelWrapper(
                original_model, original_tokenizer)
            device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
            clone = deepcopy(model)
            clone.model.to(device)

            def calculate(input_ids, token_type_ids, attention_mask):
                return clone.model(input_ids, token_type_ids, attention_mask)[0]

            interpretable_embedding = configure_interpretable_embedding_layer(
                clone.model, 'bert.embeddings')
            ref_token_id = original_tokenizer.tokenizer.pad_token_id
            sep_token_id = original_tokenizer.tokenizer.sep_token_id
            cls_token_id = original_tokenizer.tokenizer.cls_token_id

            def summarize_attributions(attributions):
                attributions = attributions.sum(dim=-1).squeeze(0)
                attributions = attributions / torch.norm(attributions)
                return attributions

            def construct_attention_mask(input_ids):
                return torch.ones_like(input_ids)

            def construct_input_ref_pos_id_pair(input_ids):
                seq_length = input_ids.size(1)
                position_ids = torch.arange(seq_length,
                                            dtype=torch.long,
                                            device=device)
                ref_position_ids = torch.zeros(seq_length,
                                               dtype=torch.long,
                                               device=device)
                position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
                ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
                return position_ids, ref_position_ids

            def squad_pos_forward_func(inputs,
                                       token_type_ids=None,
                                       attention_mask=None):
                pred = calculate(inputs, token_type_ids, attention_mask)
                return pred.max(1).values

            def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
                seq_len = input_ids.size(1)
                token_type_ids = torch.tensor(
                    [[0 if i <= sep_ind else 1 for i in range(seq_len)]],
                    device=device)
                ref_token_type_ids = torch.zeros_like(token_type_ids,
                                                      device=device)  # * -1
                return token_type_ids, ref_token_type_ids

            input_text_ids = original_tokenizer.tokenizer.encode(
                input_text, add_special_tokens=False)
            input_ids = [cls_token_id] + input_text_ids + [sep_token_id]
            input_ids = torch.tensor([input_ids], device=device)
            position_ids, ref_position_ids = construct_input_ref_pos_id_pair(
                input_ids)
            ref_input_ids = [cls_token_id] + \
                [ref_token_id] * len(input_text_ids) + [sep_token_id]
            ref_input_ids = torch.tensor([ref_input_ids], device=device)
            token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(
                input_ids, len(input_text_ids))
            attention_mask = torch.ones_like(input_ids)
            input_embeddings = interpretable_embedding.indices_to_embeddings(
                input_ids,
                token_type_ids=token_type_ids,
                position_ids=position_ids)
            ref_input_embeddings = interpretable_embedding.indices_to_embeddings(
                ref_input_ids,
                token_type_ids=ref_token_type_ids,
                position_ids=ref_position_ids)

            layer_attrs_start = []
            for i in range(len(clone.model.bert.encoder.layer)):
                lc = LayerConductance(squad_pos_forward_func,
                                      clone.model.bert.encoder.layer[i])
                layer_attributions_start = lc.attribute(
                    inputs=input_embeddings,
                    baselines=ref_input_embeddings,
                    additional_forward_args=(token_type_ids, attention_mask))[0]
                layer_attrs_start.append(
                    summarize_attributions(
                        layer_attributions_start).cpu().detach().tolist())

            all_tokens = original_tokenizer.tokenizer.convert_ids_to_tokens(
                input_ids[0])
            fig, ax = plt.subplots(figsize=(15, 5))
            xticklabels = all_tokens
            yticklabels = list(range(1, 13))  # one row per encoder layer (12 for bert-base)
            ax = sns.heatmap(np.array(layer_attrs_start),
                             xticklabels=xticklabels,
                             yticklabels=yticklabels,
                             linewidth=0.2)
            plt.xlabel('Tokens')
            plt.ylabel('Layers')
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            bufferString = base64.b64encode(buf.read())
            imageUri = urllib.parse.quote(bufferString)

            post = {
                "type": "heatmap",
                "input_string": input_text,
                "model_name": model_name,
                "recipe_name": recipe_name,
                "image": imageUri,
            }
            if STORED_POSTS:
                JSON_STORED_POSTS = json.loads(STORED_POSTS)
                JSON_STORED_POSTS.insert(0, post)
                request.session["TextAttackResult"] = json.dumps(
                    JSON_STORED_POSTS[:10])
            else:
                request.session["TextAttackResult"] = json.dumps([post])
            return HttpResponseRedirect(reverse('webdemo:index'))
        else:
            return HttpResponseNotFound('Failed')
        return HttpResponse('Success')
    return HttpResponseNotFound('<h1>Not Found</h1>')
def predict(input_abstract: str, input_title: str, input_keywords: str,
            model: nn.Module, vectors: Vectors, output_vectors: Dictionary,
            device: torch.device):
    """
    :param input_abstract:
    :param input_title:
    :param input_keywords:
    :param model:
    :param vectors:
    :param output_vectors:
    :param device:
    :return:
    """
    # Prepare for visualization
    vis_data_records_ig = []

    # Interpret sentence
    try:
        interpretable_embedding_abstracts = configure_interpretable_embedding_layer(
            model, 'embedding_abstracts')
        interpretable_embedding_titles = configure_interpretable_embedding_layer(
            model, 'embedding_titles')
        interpretable_embedding_keywords = configure_interpretable_embedding_layer(
            model, 'embedding_keywords')
        ig = IntegratedGradients(model)
    except Exception:
        exit(1)
    print(f"Created Interpretable Layers: {datetime.datetime.now()}")

    interpret_sentence(
        model=model,
        input_abstract=input_abstract,
        input_title=input_title,
        input_keywords=input_keywords,
        vectors=vectors,
        interpretable_embedding_abstracts=interpretable_embedding_abstracts,
        interpretable_embedding_titles=interpretable_embedding_titles,
        interpretable_embedding_keywords=interpretable_embedding_keywords,
        ig=ig,
        vis_data_records_ig=vis_data_records_ig,
        output_vectors=output_vectors,
        device=device)
    print(f"Interpreted: {datetime.datetime.now()}")

    # Show interpretations
    # print(build_html(vis_data_records_ig))
    json_data = build_json(vis_data_records_ig)
    print(f"Built JSON: {datetime.datetime.now()}")

    remove_interpretable_embedding_layer(model, interpretable_embedding_abstracts)
    remove_interpretable_embedding_layer(model, interpretable_embedding_titles)
    remove_interpretable_embedding_layer(model, interpretable_embedding_keywords)
    print(f"Removed Layers: {datetime.datetime.now()}")
    return json_data
batch_labels_ndarray = batch_labels.detach().cpu().numpy()
if preds is None:
    preds = logits_ndarray
    out_label_ids = inputs["labels"].detach().cpu().numpy()
else:
    preds = np.append(preds, logits_ndarray, axis=0)
    out_label_ids = np.append(out_label_ids,
                              inputs["labels"].detach().cpu().numpy(),
                              axis=0)
example_preds = np.argmax(logits_ndarray, axis=2)
if np.any(batch_labels_ndarray[0], where=batch_labels_ndarray[0] != 0):
    # interpretable_embedding = configure_interpretable_embedding_layer(
    #     deeplift_model, 'model.bert.embeddings.word_embeddings')
    interpretable_embedding = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.word_embeddings')
    input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids)
    for token_index in np.where(batch_labels_ndarray[0] != 0)[0]:
        if out_label_ids[example_index][token_index] != pad_token_label_id:
            label_id = example_preds[0][token_index]
            true_label = batch_labels[0][token_index].item()
            target = (token_index, label_id)
            logger.info(
                f'Calculating attribution for label {label_id} at index {token_index}'
            )
            attribution_start = time.time()
            # attributions, delta = explainer.attribute(
            #     input_embeddings, target=target,
            #     additional_forward_args=batch_labels,
def generate(sample, lac):
    model.zero_grad()
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
        sample.text, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(
        input_ids, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    tokens = tokenizer.convert_ids_to_tokens(indices)

    input_ids = input_ids.to(device)
    ref_input_ids = ref_input_ids.to(device)
    token_type_ids = token_type_ids.to(device)
    ref_token_type_ids = ref_token_type_ids.to(device)
    position_ids = position_ids.to(device)
    ref_position_ids = ref_position_ids.to(device)
    attention_mask = attention_mask.to(device)

    logits = forward_func(input_ids,
                          token_type_ids=token_type_ids,
                          position_ids=position_ids,
                          attention_mask=attention_mask)
    pred = int(logits.argmax(1))
    score = float(logits[0][pred])

    interpretable_embedding1 = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.word_embeddings')
    interpretable_embedding2 = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.token_type_embeddings')
    interpretable_embedding3 = configure_interpretable_embedding_layer(
        model, 'bert.embeddings.position_embeddings')

    (input_embed, ref_input_embed), \
        (token_type_ids_embed, ref_token_type_ids_embed), \
        (position_ids_embed, ref_position_ids_embed) = construct_bert_sub_embedding(
            input_ids, ref_input_ids,
            token_type_ids=token_type_ids, ref_token_type_ids=ref_token_type_ids,
            position_ids=position_ids, ref_position_ids=ref_position_ids)

    lig = IntegratedGradients(forward_func)
    attributions, delta = lig.attribute(
        inputs=(input_embed, token_type_ids_embed, position_ids_embed),
        baselines=(ref_input_embed, ref_token_type_ids_embed,
                   ref_position_ids_embed),
        target=sample.label,
        additional_forward_args=(attention_mask),
        return_convergence_delta=True)

    input_ids = input_ids.cpu()
    ref_input_ids = ref_input_ids.cpu()
    token_type_ids = token_type_ids.cpu()
    ref_token_type_ids = ref_token_type_ids.cpu()
    position_ids = position_ids.cpu()
    ref_position_ids = ref_position_ids.cpu()
    attention_mask = attention_mask.cpu()
    input_embed = input_embed.cpu()
    ref_input_embed = ref_input_embed.cpu()
    token_type_ids_embed = token_type_ids_embed.cpu()
    ref_token_type_ids_embed = ref_token_type_ids_embed.cpu()
    position_ids_embed = position_ids_embed.cpu()
    ref_position_ids_embed = ref_position_ids_embed.cpu()
    torch.cuda.empty_cache()

    _, attribution_words = word_level_spline(
        tokens,
        summarize_attributions(attributions[0]).cpu().detach().numpy(), lac)
    _, attribution_position = word_level_spline(
        tokens,
        summarize_attributions(attributions[2]).cpu().detach().numpy(), lac)
    words = [each[0] for each in attribution_words]
    attribution_merge = [
        attribution_words[i][1] + attribution_position[i][1]
        for i in range(len(attribution_words))
    ]

    remove_interpretable_embedding_layer(model, interpretable_embedding1)
    remove_interpretable_embedding_layer(model, interpretable_embedding2)
    remove_interpretable_embedding_layer(model, interpretable_embedding3)
    return Result(words=words, label=pred, attribution=attribution_merge)
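A hedged usage sketch for generate(). The Sample container below is an assumption (generate only needs an object with .text and .label), and the example sentence is made up; the attribute access on the returned Result assumes it exposes words/label/attribution, as its keyword construction above suggests.

# Hypothetical caller sketch: Sample is an assumed container; the LAC segmenter
# is built the same way as in the initialize() snippet above.
from collections import namedtuple

Sample = namedtuple("Sample", ["text", "label"])
lac = LAC(mode="seg")

result = generate(Sample(text="这家餐厅的服务态度非常好", label=1), lac)
for word, attr in zip(result.words, result.attribution):
    print(f"{word}\t{attr:+.4f}")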
model_path = "checkpoints/roberta24" model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2) model.to(device) model.eval() model.zero_grad() tokenizer_path = "hfl/chinese-roberta-wwm-ext" tokenizer = BertTokenizer.from_pretrained(tokenizer_path) ref_token_id = tokenizer.pad_token_id sep_token_id = tokenizer.sep_token_id cls_token_id = tokenizer.cls_token_id print("prepare interpretable embedding...") interpretable_embedding1 = configure_interpretable_embedding_layer( model, 'bert.embeddings.word_embeddings') interpretable_embedding2 = configure_interpretable_embedding_layer( model, 'bert.embeddings.token_type_embeddings') interpretable_embedding3 = configure_interpretable_embedding_layer( model, 'bert.embeddings.position_embeddings') remove_interpretable_embedding_layer(model, interpretable_embedding1) remove_interpretable_embedding_layer(model, interpretable_embedding2) remove_interpretable_embedding_layer(model, interpretable_embedding3) def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None): return model( inputs,
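As a hedged alternative to the configure/remove pattern above, captum's LayerIntegratedGradients can attribute against the word-embedding layer without modifying the model at all. The example sentence and target class below are assumptions; only the model, tokenizer, and special-token ids come from the snippet above.

# Hedged sketch (assumption, not part of the original script): attribution via
# LayerIntegratedGradients on the unmodified model set up above.
import torch
from captum.attr import LayerIntegratedGradients

def forward_logits(input_ids, token_type_ids=None, attention_mask=None):
    # First element of the BertForSequenceClassification output is the logits.
    return model(input_ids,
                 token_type_ids=token_type_ids,
                 attention_mask=attention_mask)[0]

model_device = next(model.parameters()).device
text = "这部电影的剧情非常精彩"  # assumed example sentence
input_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)],
                         device=model_device)
ref_input_ids = torch.full_like(input_ids, ref_token_id)
ref_input_ids[0, 0], ref_input_ids[0, -1] = cls_token_id, sep_token_id
attention_mask = torch.ones_like(input_ids)

lig = LayerIntegratedGradients(forward_logits,
                               model.bert.embeddings.word_embeddings)
attributions = lig.attribute(inputs=input_ids,
                             baselines=ref_input_ids,
                             additional_forward_args=(None, attention_mask),
                             target=1)  # assumed target class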
def generate_saliency(model_path, saliency_path, saliency, aggregation):
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_args = Namespace(**checkpoint['args'])
    if args.model == 'lstm':
        model = LSTM_MODEL(tokenizer,
                           model_args,
                           n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    elif args.model == 'trans':
        transformer_config = BertConfig.from_pretrained(
            'bert-base-uncased', num_labels=model_args.labels)
        model_cp = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', config=transformer_config).to(device)
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model_cp.load_state_dict(checkpoint['model'])
        model = BertModelWrapper(model_cp)
    else:
        model = CNN_MODEL(tokenizer,
                          model_args,
                          n_labels=checkpoint['args']['labels']).to(device)
        model.load_state_dict(checkpoint['model'])
    model.train()

    pad_to_max = False
    if saliency == 'deeplift':
        ablator = DeepLift(model)
    elif saliency == 'guided':
        ablator = GuidedBackprop(model)
    elif saliency == 'sal':
        ablator = Saliency(model)
    elif saliency == 'inputx':
        ablator = InputXGradient(model)
    elif saliency == 'occlusion':
        ablator = Occlusion(model)

    coll_call = get_collate_fn(dataset=args.dataset, model=args.model)
    return_attention_masks = args.model == 'trans'
    collate_fn = partial(coll_call,
                         tokenizer=tokenizer,
                         device=device,
                         return_attention_masks=return_attention_masks,
                         pad_to_max_length=pad_to_max)
    test = get_dataset(path=args.dataset_dir,
                       mode=args.split,
                       dataset=args.dataset)
    batch_size = args.batch_size if args.batch_size is not None else \
        model_args.batch_size
    test_dl = DataLoader(batch_size=batch_size,
                         dataset=test,
                         shuffle=False,
                         collate_fn=collate_fn)

    # PREDICTIONS
    predictions_path = model_path + '.predictions'
    if not os.path.exists(predictions_path):
        predictions = defaultdict(lambda: [])
        for batch in tqdm(test_dl, desc='Running test prediction...'):
            if args.model == 'trans':
                logits = model(batch[0],
                               attention_mask=batch[1],
                               labels=batch[2].long())
            else:
                logits = model(batch[0])
            logits = logits.detach().cpu().numpy().tolist()
            predicted = np.argmax(np.array(logits), axis=-1)
            predictions['class'] += predicted.tolist()
            predictions['logits'] += logits
        with open(predictions_path, 'w') as out:
            json.dump(predictions, out)

    # COMPUTE SALIENCY
    if saliency != 'occlusion':
        embedding_layer_name = 'model.bert.embeddings' if args.model == 'trans' \
            else 'embedding'
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, embedding_layer_name)

    class_attr_list = defaultdict(lambda: [])
    token_ids = []
    saliency_flops = []
    for batch in tqdm(test_dl, desc='Running Saliency Generation...'):
        if args.model == 'cnn':
            additional = None
        elif args.model == 'trans':
            additional = (batch[1], batch[2])
        else:
            additional = batch[-1]
        token_ids += batch[0].detach().cpu().numpy().tolist()
        if saliency != 'occlusion':
            input_embeddings = interpretable_embedding.indices_to_embeddings(
                batch[0])
        if not args.no_time:
            high.start_counters([events.PAPI_FP_OPS, ])
        for cls_ in range(checkpoint['args']['labels']):
            if saliency == 'occlusion':
                attributions = ablator.attribute(
                    batch[0],
                    sliding_window_shapes=(args.sw, ),
                    target=cls_,
                    additional_forward_args=additional)
            else:
                attributions = ablator.attribute(
                    input_embeddings,
                    target=cls_,
                    additional_forward_args=additional)
            attributions = summarize_attributions(
                attributions, type=aggregation, model=model,
                tokens=batch[0]).detach().cpu().numpy().tolist()
            class_attr_list[cls_] += [[_li for _li in _l] for _l in attributions]
        if not args.no_time:
            saliency_flops.append(
                sum(high.stop_counters()) / batch[0].shape[0])

    if saliency != 'occlusion':
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    # SERIALIZE
    print('Serializing...', flush=True)
    with open(saliency_path, 'w') as out:
        for instance_i, _ in enumerate(test):
            saliencies = []
            for token_i, token_id in enumerate(token_ids[instance_i]):
                token_sal = {'token': tokenizer.ids_to_tokens[token_id]}
                for cls_ in range(checkpoint['args']['labels']):
                    token_sal[int(cls_)] = class_attr_list[cls_][instance_i][token_i]
                saliencies.append(token_sal)
            out.write(json.dumps({'tokens': saliencies}) + '\n')
            out.flush()
    return saliency_flops
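A hedged usage sketch for generate_saliency(). The checkpoint path, output paths, and aggregation mode are assumptions; the sketch just shows the pairing of one saliency file per attribution method and the returned FLOP estimates.

# Hypothetical driver loop: paths and the 'mean' aggregation are assumptions.
for saliency in ['sal', 'inputx', 'guided', 'deeplift', 'occlusion']:
    flops = generate_saliency(
        model_path='snapshots/cnn_sst2_best.pt',
        saliency_path=f'saliency/cnn_sst2_{saliency}_mean',
        saliency=saliency,
        aggregation='mean')
    print(saliency, np.mean(flops) if flops else 'timing disabled')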
def predict(self, text):
    self.model.zero_grad()
    input_ids, ref_input_ids, sep_id = self.construct_input_ref_pair(
        text, self.ref_token_id, self.sep_token_id, self.cls_token_id)
    token_type_ids, ref_token_type_ids = self.construct_input_ref_token_type_pair(
        input_ids, sep_id)
    position_ids, ref_position_ids = self.construct_input_ref_pos_id_pair(
        input_ids)
    attention_mask = self.construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    tokens = self.tokenizer.convert_ids_to_tokens(indices)

    logits = self.forward_func(input_ids,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               attention_mask=attention_mask)
    pred = int(logits.argmax(1))
    logits = logits.tolist()[0]

    self.interpretable_embedding1 = configure_interpretable_embedding_layer(
        self.model, 'bert.embeddings.word_embeddings')
    self.interpretable_embedding2 = configure_interpretable_embedding_layer(
        self.model, 'bert.embeddings.token_type_embeddings')
    self.interpretable_embedding3 = configure_interpretable_embedding_layer(
        self.model, 'bert.embeddings.position_embeddings')

    (input_embed, ref_input_embed), \
        (token_type_ids_embed, ref_token_type_ids_embed), \
        (position_ids_embed, ref_position_ids_embed) = self.construct_bert_sub_embedding(
            input_ids, ref_input_ids,
            token_type_ids=token_type_ids, ref_token_type_ids=ref_token_type_ids,
            position_ids=position_ids, ref_position_ids=ref_position_ids)

    lig = IntegratedGradients(self.forward_func)
    attr = []
    for label in range(len(logits)):
        attributions, delta = lig.attribute(
            inputs=(input_embed, token_type_ids_embed, position_ids_embed),
            baselines=(ref_input_embed, ref_token_type_ids_embed,
                       ref_position_ids_embed),
            target=label,
            additional_forward_args=(attention_mask),
            return_convergence_delta=True)
        _, attribution_words = self.word_level_spline(
            tokens,
            self.summarize_attributions(attributions[0]).cpu().detach().numpy(),
            self.lac)
        _, attribution_position = self.word_level_spline(
            tokens,
            self.summarize_attributions(attributions[2]).cpu().detach().numpy(),
            self.lac)
        if len(attribution_words) != 0:
            words = [each[0] for each in attribution_words]
            attribution_merge = [
                attribution_words[i][1] + attribution_position[i][1]
                for i in range(len(attribution_words))
            ]
            # Normalize with numpy; dividing a plain Python list in place
            # would raise a TypeError.
            range_limit = np.max(np.abs(attribution_merge))
            attribution_merge = np.asarray(attribution_merge) / range_limit
            attr.append(attribution_merge.tolist())
        else:
            words = []
            attr.append([])

    remove_interpretable_embedding_layer(self.model, self.interpretable_embedding1)
    remove_interpretable_embedding_layer(self.model, self.interpretable_embedding2)
    remove_interpretable_embedding_layer(self.model, self.interpretable_embedding3)
    return words, logits, pred, attr
        position_ids=position_ids,
        attention_mask=attention_mask)
    pred = pred[position]
    return pred.max(1).values

ref_token_id = tokenizer.pad_token_id  # Optional[int]
sep_token_id = tokenizer.sep_token_id  # Optional[int]
cls_token_id = tokenizer.cls_token_id  # Optional[int]

interpretable_embedding: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(
    model, 'bert.embeddings')
interpretable_embedding1: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(
    model, 'bert.embeddings.word_embeddings')
interpretable_embedding2: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(
    model, 'bert.embeddings.token_type_embeddings')
interpretable_embedding3: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(
    model, 'bert.embeddings.position_embeddings')

def construct_input_ref_pair(question: str, text: str,
                             ref_token_id: int | str,
                             sep_token_id: int | str,
                             cls_token_id: int | str) -> (torch.Tensor, torch.Tensor, int):
    question_ids: list = tokenizer.encode(question, add_special_tokens=False)
    text_ids: list = tokenizer.encode(text, add_special_tokens=False)
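The helper above is cut off after the two encode calls, so its original body is not shown. As a hedged, self-contained sketch, the remainder commonly mirrors captum's BERT question-answering example: build [CLS] question [SEP] text [SEP] for the real input, replace every non-special position with the reference token for the baseline, and return the separator index; everything below is an assumption rather than the original code.

# Hedged sketch (assumed completion, following captum's BERT QA example pattern):
def construct_input_ref_pair_sketch(question: str, text: str, ref_token_id: int,
                                    sep_token_id: int, cls_token_id: int):
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # Real input: [CLS] question [SEP] text [SEP]
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]
    # Baseline: same layout with the reference (pad) token in every content position
    ref_input_ids = ([cls_token_id] + [ref_token_id] * len(question_ids) +
                     [sep_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id])
    return torch.tensor([input_ids]), torch.tensor([ref_input_ids]), len(question_ids)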