def _calculate_attributions(self, embeddings: Embedding):  # type: ignore
    """Compute start- and end-position attributions for the question/text pair.

    Builds input/reference id pairs from ``self.question`` and ``self.text``,
    plus position-id, token-type-id and attention-mask inputs. For the
    ``"lig"`` attribution type, it then runs ``LIGAttributions`` twice:
    once with ``self.position = 0`` and once with ``self.position = 1``.

    Args:
        embeddings: Embedding layer handed to ``LIGAttributions``.

    Side effects:
        Sets ``input_ids``, ``ref_input_ids``, ``sep_idx``, ``position_ids``,
        ``ref_position_ids``, ``token_type_ids``, ``ref_token_type_ids`` and
        ``attention_mask`` on ``self``. For ``"lig"`` it also sets
        ``start_attributions``, ``end_attributions`` and ``attributions``
        (the list ``[start, end]``), leaving ``self.position == 1``. For any
        other attribution type, only the inputs are built.
    """
    (
        self.input_ids,
        self.ref_input_ids,
        self.sep_idx,
    ) = self._make_input_reference_pair(self.question, self.text)

    (
        self.position_ids,
        self.ref_position_ids,
    ) = self._make_input_reference_position_id_pair(self.input_ids)

    (
        self.token_type_ids,
        self.ref_token_type_ids,
    ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

    self.attention_mask = self._make_attention_mask(self.input_ids)

    if self.attribution_type != "lig":
        return

    # Drop every "Ġ" character from the decoded tokens (presumably the
    # byte-level BPE space marker; confirm for the tokenizer in use).
    reference_tokens = [
        token.replace("Ġ", "") for token in self.decode(self.input_ids)
    ]

    def summarized_lig(position):
        # NOTE(review): self._forward presumably reads self.position to pick
        # the start (0) or end (1) output; confirm. Keep this assignment
        # ahead of the LIGAttributions call.
        self.position = position
        lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
        )
        lig.summarize()
        return lig

    self.start_attributions = summarized_lig(0)
    self.end_attributions = summarized_lig(1)
    self.attributions = [self.start_attributions, self.end_attributions]
Example #2
0
    def _calculate_attributions(  # type: ignore
            self,
            embeddings: Embedding,
            index: int = None,
            class_name: str = None):
        """Compute attributions for ``self.text`` toward the selected class.

        Builds the input/reference id pairs, the position-id pairs and the
        attention mask for ``self.text``, then resolves which output index to
        attribute. For the ``"lig"`` attribution type, it stores a summarized
        ``LIGAttributions`` object on ``self.attributions``.

        Args:
            embeddings: Embedding layer handed to ``LIGAttributions``.
            index: Output index to attribute. Takes precedence over
                ``class_name`` when not ``None``.
            class_name: Label looked up in ``self.label2id``. If it is not a
                key there, a warning is emitted and
                ``self.predicted_class_index`` is used instead.

        Side effects:
            Sets ``input_ids``, ``ref_input_ids``, ``sep_idx``,
            ``position_ids``, ``ref_position_ids``, ``attention_mask`` and
            ``selected_index`` on ``self``. Sets ``attributions`` only for
            the ``"lig"`` attribution type.
        """
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        if index is not None:
            self.selected_index = index
        elif class_name is not None:
            if class_name in self.label2id:
                self.selected_index = self.label2id[class_name]
            else:
                # Trailing space keeps the two sentences from running together.
                s = f"'{class_name}' is not found in self.label2id keys. "
                s += "Defaulting to predicted index instead."
                warnings.warn(s)
                self.selected_index = self.predicted_class_index
        else:
            self.selected_index = self.predicted_class_index

        if self.attribution_type != "lig":
            return

        # Drop every "Ġ" character from the decoded tokens (presumably the
        # byte-level BPE space marker; confirm for the tokenizer in use).
        reference_tokens = [
            token.replace("Ġ", "") for token in self.decode(self.input_ids)
        ]
        lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
        )
        lig.summarize()
        self.attributions = lig
Example #3
0
    def _calculate_attributions(self,
                                index: int = None,
                                class_name: str = None):
        """Compute attributions for ``self.text`` toward the selected class.

        Builds the input/reference id pair for ``self.text`` and resolves
        which output index to attribute. For the ``"lig"`` attribution type,
        it stores a summarized ``LIGAttributions`` object on
        ``self.attributions``.

        Args:
            index: Output index to attribute. Takes precedence over
                ``class_name`` when not ``None``.
            class_name: Label looked up in ``self.label2id``. If it is not a
                key there, a warning is emitted and
                ``self.predicted_class_index`` is used instead.

        Side effects:
            Sets ``input_ids``, ``ref_input_ids``, ``sep_idx`` and
            ``selected_index`` on ``self``. Sets ``attributions`` only for
            the ``"lig"`` attribution type.
        """
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text)

        if index is not None:
            self.selected_index = index
        elif class_name is not None:
            if class_name in self.label2id:
                self.selected_index = self.label2id[class_name]
            else:
                # Trailing space keeps the two sentences from running together.
                s = f"'{class_name}' is not found in self.label2id keys. "
                s += "Defaulting to predicted index instead."
                warnings.warn(s)
                self.selected_index = self.predicted_class_index
        else:
            self.selected_index = self.predicted_class_index

        if self.attribution_type != "lig":
            return

        # The embedding layer comes from the model sub-module whose
        # attribute name is self.model_type.
        embeddings = getattr(self.model, self.model_type).embeddings
        # NOTE(review): the reference is passed as a single string wrapped in
        # literal "BOS_TOKEN"/"EOS_TOKEN" markers; confirm that the
        # LIGAttributions version in use expects a string here.
        reference_text = "BOS_TOKEN " + self.text + " EOS_TOKEN"
        lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_text,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        )
        lig.summarize()
        self.attributions = lig
Example #4
0
    def _calculate_attributions(self,
                                embeddings: Embedding,
                                class_name: str,
                                index: int = None):  # type: ignore
        """Compute attributions of the text/hypothesis pair for ``class_name``.

        Builds input/reference id pairs from ``self.text`` and
        ``self.hypothesis_text``, plus position-id, token-type-id and
        attention-mask inputs. It then runs ``LIGAttributions`` and appends
        the summarized result to ``self.attributions``.

        Args:
            embeddings: Embedding layer handed to ``LIGAttributions``.
            class_name: Label whose ``self.label2id`` entry, converted with
                ``int``, becomes ``self.selected_index``. A missing label
                propagates whatever the lookup raises (``KeyError`` for a
                plain dict).
            index: Accepted but not read by this method.

        Side effects:
            Sets the id/position/token-type/attention-mask attributes and
            ``selected_index`` on ``self``, and appends to
            ``self.attributions``. That attribute is presumably a list
            initialized elsewhere; confirm.
        """
        self.input_ids, self.ref_input_ids, self.sep_idx = (
            self._make_input_reference_pair(self.text, self.hypothesis_text)
        )
        self.position_ids, self.ref_position_ids = (
            self._make_input_reference_position_id_pair(self.input_ids)
        )
        self.token_type_ids, self.ref_token_type_ids = (
            self._make_input_reference_token_type_pair(self.input_ids,
                                                       self.sep_idx)
        )
        self.attention_mask = self._make_attention_mask(self.input_ids)

        label_id = self.label2id[class_name]
        self.selected_index = int(label_id)

        # Decoded tokens with every "Ġ" character removed.
        reference_tokens = []
        for raw_token in self.decode(self.input_ids):
            reference_tokens.append(raw_token.replace("Ġ", ""))

        lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        # Without the hypothesis, summarize() receives sep_idx (presumably to
        # stop the summary at the separator; confirm).
        summarize_args = () if self.include_hypothesis else (self.sep_idx,)
        lig.summarize(*summarize_args)
        self.attributions.append(lig)