def _calculate_attributions(self, embeddings: Embedding):  # type: ignore
    """Compute LIG attributions for both the answer-start and answer-end logits.

    Builds the input/reference pairs for the current ``self.question`` /
    ``self.text``, then runs Layer Integrated Gradients twice — once with
    ``self.position = 0`` (start logits) and once with ``self.position = 1``
    (end logits) — storing the results on ``self.start_attributions``,
    ``self.end_attributions`` and ``self.attributions``.

    Args:
        embeddings: The embedding layer of the model to attribute against.
    """
    (
        self.input_ids,
        self.ref_input_ids,
        self.sep_idx,
    ) = self._make_input_reference_pair(self.question, self.text)
    (
        self.position_ids,
        self.ref_position_ids,
    ) = self._make_input_reference_position_id_pair(self.input_ids)
    (
        self.token_type_ids,
        self.ref_token_type_ids,
    ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)
    self.attention_mask = self._make_attention_mask(self.input_ids)

    if self.attribution_type == "lig":
        # "Ġ" is the BPE space marker emitted by GPT-2/RoBERTa-style
        # tokenizers; strip it so the displayed tokens are clean.
        reference_tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]

        def _summarized_lig(position: int) -> "LIGAttributions":
            # ``self.position`` steers ``self._forward`` toward the start (0)
            # or end (1) logits before attributions are computed; both LIG
            # runs are otherwise identical, hence this shared helper.
            self.position = position
            lig = LIGAttributions(
                self._forward,
                embeddings,
                reference_tokens,
                self.input_ids,
                self.ref_input_ids,
                self.sep_idx,
                self.attention_mask,
                position_ids=self.position_ids,
                ref_position_ids=self.ref_position_ids,
                token_type_ids=self.token_type_ids,
                ref_token_type_ids=self.ref_token_type_ids,
            )
            lig.summarize()
            return lig

        self.start_attributions = _summarized_lig(0)
        self.end_attributions = _summarized_lig(1)
        self.attributions = [self.start_attributions, self.end_attributions]
    else:
        # NOTE(review): attribution types other than "lig" are silently
        # ignored here — presumably handled (or rejected) upstream; confirm.
        pass
def _calculate_attributions(  # type: ignore
    self, embeddings: Embedding, index: int = None, class_name: str = None
):
    """Compute LIG attributions for one output class of a classifier.

    Builds input/reference pairs for ``self.text``, resolves which class
    index to attribute (explicit ``index`` wins, then a recognized
    ``class_name``, otherwise the model's predicted class), and stores the
    summarized result on ``self.attributions``.

    Args:
        embeddings: The embedding layer of the model to attribute against.
        index: Explicit output-class index to explain, if given.
        class_name: Label name looked up in ``self.label2id``; falls back to
            the predicted class (with a warning) when unknown.
    """
    (
        self.input_ids,
        self.ref_input_ids,
        self.sep_idx,
    ) = self._make_input_reference_pair(self.text)
    (
        self.position_ids,
        self.ref_position_ids,
    ) = self._make_input_reference_position_id_pair(self.input_ids)
    self.attention_mask = self._make_attention_mask(self.input_ids)

    # Resolve the class to attribute: index > known class name > prediction.
    if index is not None:
        self.selected_index = index
    elif class_name is not None:
        if class_name in self.label2id:
            self.selected_index = self.label2id[class_name]
        else:
            # Fix: a space was missing between the two sentences, producing
            # "keys.Defaulting" in the warning text.
            s = f"'{class_name}' is not found in self.label2id keys. "
            s += "Defaulting to predicted index instead."
            warnings.warn(s)
            self.selected_index = self.predicted_class_index
    else:
        self.selected_index = self.predicted_class_index

    if self.attribution_type == "lig":
        # Strip the "Ġ" BPE space marker used by GPT-2/RoBERTa tokenizers.
        reference_tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]
        lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
        )
        lig.summarize()
        self.attributions = lig
    else:
        # NOTE(review): non-"lig" attribution types are silently ignored.
        pass
def _calculate_attributions(self, index: int = None, class_name: str = None):
    """Compute LIG attributions for one output class, locating embeddings itself.

    Unlike the sibling variants, this one fetches the embedding layer from
    ``self.model`` via ``self.model_type`` and passes a BOS/EOS-wrapped
    reference *string* (not a token list) to ``LIGAttributions``. The
    summarized result is stored on ``self.attributions``.

    Args:
        index: Explicit output-class index to explain, if given.
        class_name: Label name looked up in ``self.label2id``; falls back to
            the predicted class (with a warning) when unknown.
    """
    (
        self.input_ids,
        self.ref_input_ids,
        self.sep_idx,
    ) = self._make_input_reference_pair(self.text)

    # Resolve the class to attribute: index > known class name > prediction.
    if index is not None:
        self.selected_index = index
    elif class_name is not None:
        if class_name in self.label2id:
            self.selected_index = self.label2id[class_name]
        else:
            # Fix: a space was missing between the two sentences, producing
            # "keys.Defaulting" in the warning text.
            s = f"'{class_name}' is not found in self.label2id keys. "
            s += "Defaulting to predicted index instead."
            warnings.warn(s)
            self.selected_index = self.predicted_class_index
    else:
        self.selected_index = self.predicted_class_index

    if self.attribution_type == "lig":
        # Grab the embedding layer off the underlying transformer module
        # (e.g. ``self.model.bert.embeddings``).
        embeddings = getattr(self.model, self.model_type).embeddings
        # Placeholder sentinel tokens bracket the text to mirror the
        # reference input ids built above.
        reference_text = "BOS_TOKEN " + self.text + " EOS_TOKEN"
        lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_text,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        )
        lig.summarize()
        self.attributions = lig
    else:
        # NOTE(review): non-"lig" attribution types are silently ignored.
        pass
def _calculate_attributions(self, embeddings: Embedding, class_name: str, index: int = None):  # type: ignore
    """Compute and append LIG attributions for a single zero-shot label.

    Pairs ``self.text`` with ``self.hypothesis_text``, attributes the output
    class given by ``class_name`` (looked up in ``self.label2id``), and
    appends the summarized result to ``self.attributions``. When
    ``self.include_hypothesis`` is false, summarization is truncated at the
    separator index so hypothesis tokens are excluded.

    Args:
        embeddings: The embedding layer of the model to attribute against.
        class_name: Label whose output logit is attributed; must be a key of
            ``self.label2id``.
        index: Unused; kept for signature compatibility with siblings.
    """
    pair = self._make_input_reference_pair(self.text, self.hypothesis_text)
    self.input_ids, self.ref_input_ids, self.sep_idx = pair

    position_pair = self._make_input_reference_position_id_pair(self.input_ids)
    self.position_ids, self.ref_position_ids = position_pair

    type_pair = self._make_input_reference_token_type_pair(
        self.input_ids, self.sep_idx
    )
    self.token_type_ids, self.ref_token_type_ids = type_pair

    self.attention_mask = self._make_attention_mask(self.input_ids)
    self.selected_index = int(self.label2id[class_name])

    # Strip the "Ġ" BPE space marker from decoded tokens for display.
    cleaned_tokens = [tok.replace("Ġ", "") for tok in self.decode(self.input_ids)]

    attribution = LIGAttributions(
        self._forward,
        embeddings,
        cleaned_tokens,
        self.input_ids,
        self.ref_input_ids,
        self.sep_idx,
        self.attention_mask,
        position_ids=self.position_ids,
        ref_position_ids=self.ref_position_ids,
        token_type_ids=self.token_type_ids,
        ref_token_type_ids=self.ref_token_type_ids,
        internal_batch_size=self.internal_batch_size,
        n_steps=self.n_steps,
    )

    # Summarize everything, or only up to the separator when the hypothesis
    # tokens should be excluded from the attribution summary.
    if self.include_hypothesis:
        attribution.summarize()
    else:
        attribution.summarize(self.sep_idx)

    self.attributions.append(attribution)