Example #1
def interpret_sentence(model, text, text_lengths, args, label=0):

    # set up the attribution method for the chosen model type
    if 'BERT' in args.model:
        PAD_IND = args.bert_tokenizer.pad_token_id
        lig = LayerIntegratedGradients(model, model.model.embeddings)
    else:
        PAD_IND = args.TEXT.vocab.stoi['<pad>']
        lig = LayerIntegratedGradients(model, model.embedding)
    token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

    model.zero_grad()

    # predict
    start = time.time()
    pred = model(text, text_lengths).squeeze(0)
    print("time:", time.time() - start)
    pred_ind = torch.argmax(pred).item()

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(
        text.shape[1], device=args.device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig_1 = lig.attribute((text, text_lengths),
                                      (reference_indices, text_lengths),
                                      target=0,
                                      n_steps=100,
                                      return_convergence_delta=False)

    attributions_ig_2 = lig.attribute((text, text_lengths),
                                      (reference_indices, text_lengths),
                                      target=1,
                                      n_steps=100,
                                      return_convergence_delta=False)

    if 'BERT' in args.model:
        sentence = [
            args.bert_tokenizer.ids_to_tokens[int(word)]
            for word in text.squeeze(0).cpu().numpy()
            if int(word) != args.bert_tokenizer.pad_token_id
        ]
    else:
        sentence = [
            args.TEXT.vocab.itos[int(word)]
            for word in text.squeeze(0).cpu().numpy()
        ]
    # print(sentence)

    add_attributions_to_visualizer(attributions_ig_1, sentence, pred, pred_ind,
                                   label, args)
    add_attributions_to_visualizer(attributions_ig_2, sentence, pred, pred_ind,
                                   label, args)
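
The helper add_attributions_to_visualizer is not shown in this example; a minimal sketch of what it might look like, modeled on Captum's IMDB tutorial (the args.vis_data_records list and the exact VisualizationDataRecord arguments are assumptions):

import torch
from captum.attr import visualization

def add_attributions_to_visualizer(attributions, sentence, pred, pred_ind,
                                   label, args):
    # collapse the embedding dimension, L2-normalize per token, and store a
    # record that visualization.visualize_text can later render
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()
    args.vis_data_records.append(  # hypothetical accumulator on args
        visualization.VisualizationDataRecord(
            attributions, torch.max(pred).item(), pred_ind, label, pred_ind,
            attributions.sum(), sentence, 0))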
Example #2
def get_attributions(model, text):
    """

    Returns:
        - tokens: An array of tokens
        - attrs: An array of attributions, of same size as 'tokens',
          with attrs[i] being the attribution to tokens[i]

     """

    # tokenize text
    tokenized = tokenizer.encode_plus(text,
                                      pad_to_max_length=True,
                                      max_length=512)
    input_ids = torch.tensor(tokenized['input_ids']).to(device)
    input_ids = input_ids.view((1, -1))

    tokenized = [x for x in tokenized['input_ids'] if x != 0]
    tokenized_text = tokenizer.convert_ids_to_tokens(tokenized)

    lig = LayerIntegratedGradients(model, model.bert.embeddings)
    attributions, delta = lig.attribute(input_ids,
                                        internal_batch_size=10,
                                        return_convergence_delta=True)

    attributions = attributions.sum(dim=-1)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions[0][:len(tokenized)]

    return tokenized_text, attributions, delta
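
A minimal usage sketch, assuming the module-level model, tokenizer, and device this function relies on are already in place:

tokens, attrs, delta = get_attributions(model, "the movie was surprisingly good")
for token, score in zip(tokens, attrs):
    print(f"{token}\t{score:.4f}")
print("convergence delta:", delta.item())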
Example #3
def get_lig_object(model, model_class_item):
    insight_supported = model_class_item.get('insight_supported', False)
    internal_model_name = model_class_item['internal_model_name']
    lig = None  # default is None.
    if not insight_supported:
        logger.info(f"Inspection for model '{model_class_item['model_class_name']}' is not supported.")
        return lig

    if isinstance(internal_model_name, list):
        current_layer = model
        for layer_n in internal_model_name:
            current_layer = getattr(current_layer, layer_n)
        # print(current_layer)
        lig = LayerIntegratedGradients(model, current_layer)
    else:
        lig = LayerIntegratedGradients(
            get_model_prediction,
            getattr(model, internal_model_name).embeddings.word_embeddings)
    return lig
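
For illustration, hypothetical model_class_item entries that would exercise each branch (field values are assumptions, not from a real registry):

# list form: the layer is resolved by walking attributes, here model.bert.encoder
item_nested = {
    'model_class_name': 'BertForSequenceClassification',
    'insight_supported': True,
    'internal_model_name': ['bert', 'encoder'],
}

# string form: attributions target <name>.embeddings.word_embeddings
item_flat = {
    'model_class_name': 'RobertaForSequenceClassification',
    'insight_supported': True,
    'internal_model_name': 'roberta',
}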
Example #4
 def initialize(self, model_path):
     print("initial tokenizer...")
     self.tokenizer = DataIterator().tokenizer
     self.PAD_IND = self.tokenizer.vocab.stoi['<pad>']
     self.token_reference = TokenReferenceBase(
         reference_token_idx=self.PAD_IND)
     print("initial inference model...")
     self.model = torch.load(model_path, map_location="cpu").eval()
     print("initial attribution method ... ")
     self.lig = LayerIntegratedGradients(self.model, self.model.embedding)
Example #5
    def __init__(
        self,
        custom_forward: Callable,
        embeddings: nn.Module,
        tokens: list,
        input_ids: torch.Tensor,
        ref_input_ids: torch.Tensor,
        sep_id: int,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        ref_token_type_ids: torch.Tensor = None,
        ref_position_ids: torch.Tensor = None,
    ):
        super().__init__(custom_forward, embeddings, tokens)
        self.input_ids = input_ids
        self.ref_input_ids = ref_input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.position_ids = position_ids
        self.ref_token_type_ids = ref_token_type_ids
        self.ref_position_ids = ref_position_ids

        self.lig = LayerIntegratedGradients(self.custom_forward,
                                            self.embeddings)

        if self.token_type_ids is not None and self.position_ids is not None:
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.token_type_ids,
                        self.position_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_token_type_ids,
                    self.ref_position_ids,
                ),
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask,),
            )
        elif self.position_ids is not None:
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.position_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_position_ids,
                ),
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask,),
            )
        else:
            self._attributions, self.delta = self.lig.attribute(
                inputs=self.input_ids,
                baselines=self.ref_input_ids,
                return_convergence_delta=True,
            )
Example #6
 def attribution(self):
     #self.logits = self.model(self.input_ids, token_type_ids=self.token_type_ids, attention_mask=self.attention_mask, )
     #self.prediction = torch.argmax(self.logits[0])
     #self.sentclass_pos_forward_func = self.logits[0].max(1).values
     
     lig = LayerIntegratedGradients(self.sentclass_pos_forward_func, self.model.bert.embeddings)
     attributions_start, self.delta_start = lig.attribute(
         inputs=self.input_ids,
         baselines=self.ref_input_ids,
         additional_forward_args=(self.token_type_ids, self.attention_mask),
         return_convergence_delta=True)
     attributions_start = attributions_start.sum(dim=-1).squeeze(0)
     self.attributions_start_summary = attributions_start / torch.norm(attributions_start)
     #self.attributions_start_summary = self.attributions_start_summary.detach().tolist()
     return self.attributions_start_summary
Example #7
def main(cfg):
    # Initialize the dataset
    blastchar_dataset = BlastcharDataset(cfg.dataset.path)
    NUM_CATEGORICAL_COLS = blastchar_dataset.num_categorical_cols
    NUM_CONTINUOUS_COLS = blastchar_dataset.num_continious_cols
    EMBED_DIM = 32

    # initialize the model with its arguments
    mlp = nn.Sequential(
        nn.Linear(NUM_CATEGORICAL_COLS * EMBED_DIM + NUM_CONTINUOUS_COLS, 50),
        nn.ReLU(), nn.BatchNorm1d(50), nn.Dropout(cfg.params.dropout),
        nn.Linear(50, 20), nn.ReLU(), nn.BatchNorm1d(20),
        nn.Dropout(cfg.params.dropout),
        nn.Linear(20, blastchar_dataset.num_classes))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TabTransformer(blastchar_dataset.num_categories,
                           mlp,
                           embed_dim=EMBED_DIM,
                           num_cont_cols=NUM_CONTINUOUS_COLS)
    model.load_state_dict(torch.load(cfg.params.weights), strict=False)
    model = model.to(device)
    model.eval()

    model = ModelInputWrapper(model)

    cat, cont, _ = blastchar_dataset[0]
    cat, cont = cat.unsqueeze(0).long(), cont.unsqueeze(0).float()
    cat = torch.cat((cat, cat), dim=0)
    cont = torch.cat((cont, cont), dim=0)
    inputs = (cat, cont)

    outs = model(*inputs)
    preds = outs.argmax(-1)

    attr = LayerIntegratedGradients(
        model, [model.module.embed, model.module.layer_norm])

    attributions, _ = attr.attribute(
        inputs=(cat, cont),
        baselines=(torch.zeros_like(cat, dtype=torch.long),
                   torch.zeros_like(cont, dtype=torch.float32)),
        target=preds.detach(),
        n_steps=30,
        return_convergence_delta=True)

    print(f'attributions: {attributions[0].shape, attributions[1].shape}')
    pprint(torch.cat((attributions[0].sum(dim=2), attributions[1]), dim=1))
Example #8
 def initialize(self, context):
     """
     Loads the model and Initializes the necessary artifacts
     """
     super().initialize(context)
     self.initialized = False
     source_vocab = self.manifest['model'].get('sourceVocab')
     if source_vocab:
         # Backward compatibility
         self.source_vocab = torch.load(source_vocab)
     else:
         self.source_vocab = torch.load(self.get_source_vocab_path(context))
     # Captum initialization
     self.lig = LayerIntegratedGradients(self.model, self.model.embedding)
     self.initialized = True
Example #9
 def __init__(
     self,
     custom_forward: Callable,
     embeddings: nn.Module,
     text: str,
     input_ids: torch.Tensor,
     ref_input_ids: torch.Tensor,
     sep_id: int,
 ):
     super().__init__(custom_forward, embeddings, text)
     self.input_ids = input_ids
     self.ref_input_ids = ref_input_ids
     self.lig = LayerIntegratedGradients(self.custom_forward, self.embeddings)
     self._attributions, self.delta = self.lig.attribute(
         inputs=self.input_ids,
         baselines=self.ref_input_ids,
         return_convergence_delta=True,
     )
Example #10
def attribute_integrated_gradients(
        text_input_ids: torch.Tensor, ref_input_ids: torch.Tensor, target: int,
        model: BertForSequenceClassification,
        **kwargs) -> Tuple[np.ndarray, Dict[str, float]]:
    def forward(model_input):
        pred = model(model_input)
        return torch.softmax(pred[0], dim=1)

    lig = LayerIntegratedGradients(forward, model.bert.embeddings)

    attributions, delta = lig.attribute(inputs=text_input_ids,
                                        target=target,
                                        baselines=ref_input_ids,
                                        return_convergence_delta=True)

    scores = attributions.sum(dim=-1).squeeze(0)
    scores = scores.cpu().detach().numpy()

    return scores, {"delta": delta.item()}
Example #11
def lig_explain(inputs: Any, target: int, forward: Callable,
                embedding_layer: nn.Module) -> torch.Tensor:
    """Interpretability algorithm (Integrated Gradients) that assigns
    an importance score to each input token

    Args:
        inputs: Input for token embedding layer.
        target (int): Index of label for interpretation.
        forward (Callable): The forward function of the model or any
            modification of it.
        embedding_layer: Token embedding layer for which attributions are
            computed.

    Returns:
        Tensor of importance score to each input token
    """
    lig = LayerIntegratedGradients(forward, embedding_layer)
    attributions = lig.attribute(inputs, target=target)
    attributions = reduce_embedding_attributions(attributions)
    return attributions
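
reduce_embedding_attributions is not defined in this snippet; a plausible sketch, consistent with the per-token summarization used by the other examples here:

import torch

def reduce_embedding_attributions(attributions: torch.Tensor) -> torch.Tensor:
    # assumed behavior: collapse (batch, seq, embed) into one score per token
    attributions = attributions.sum(dim=-1).squeeze(0)
    return attributions / torch.norm(attributions)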
Example #12
    def get_scores_and_attributions(self, inputs, tok_e1_idx, tok_e2_idx, str_label):
        input_ids, attention_mask = inputs["input_ids"], \
                                    inputs["attention_mask"]
    
        input_ids_tensor, ref_input_ids_tensor = self._construct_input_ref_pair(input_ids)
        #token_type_ids_tensor, ref_token_type_ids_tensor = self._construct_input_ref_token_type_pair(token_type_ids)
        attention_mask_tensor = torch.tensor([attention_mask], device=self.device)
        e1_pos_tensor = torch.tensor([tok_e1_idx], device=self.device)
        e2_pos_tensor = torch.tensor([tok_e2_idx], device=self.device)
        labels_tensor = torch.tensor([CLASSES.index(str_label)], device=self.device)
        

        indices = input_ids_tensor[0].detach().tolist()
        all_tokens = self.tokenizer.convert_ids_to_tokens(indices)

        _, pred_scores, _ = self.predict(input_ids_tensor,
                                         #token_type_ids=token_type_ids_tensor,
                                         attention_mask=attention_mask_tensor,
                                         labels=labels_tensor,
                                         e1_pos=e1_pos_tensor,
                                         e2_pos=e2_pos_tensor)

        lig = LayerIntegratedGradients(self.trc_forward_func, self.model.roberta.embeddings)

        attributions, delta = lig.attribute(inputs=input_ids_tensor,
                                      baselines=ref_input_ids_tensor,
                                      additional_forward_args=(None,#token_type_ids_tensor,
                                                               attention_mask_tensor,
                                                               labels_tensor,
                                                               e1_pos_tensor,
                                                               e2_pos_tensor),
                                      return_convergence_delta=True)

        attributions_sum = summarize_attributions(attributions)

        return pred_scores, all_tokens, attributions_sum, delta
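
summarize_attributions is defined elsewhere in this project; a minimal sketch consistent with how it is used across these examples:

def summarize_attributions(attributions):
    # collapse the embedding dimension and L2-normalize per token
    attributions = attributions.sum(dim=-1).squeeze(0)
    return attributions / torch.norm(attributions)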
Example #13
    print(attributions_ig)
    add_attributions_to_visualizer(attributions_ig, tokens, token_ids, pred,
                                   pred_ind, label, delta, vis_data_records_ig)
    visualization.visualize_text(vis_data_records_ig)


device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
data = pkl.load(open(f"data/{datasets}/graph/ind.{datasets}_id", 'rb'),
                encoding='latin1')
adj = sp.csr_matrix(
    (data.edge_attr.cpu().numpy(),
     (data.edge_index[0].cpu().numpy(), data.edge_index[1].cpu().numpy())),
    shape=(data.x.shape[0], data.x.shape[0]))
adj = adj.toarray()

model = GCNNet(node_size=29426,
               embed_dim=200,
               hidden_dim=256,
               embedding_finetune=True,
               num_class=2,
               dropout=0.5,
               layers=2).to(device)
model.eval()
model.load_state_dict(
    torch.load(
        '/mnt/nlp-lq/yujunshuai/code/explainable_GCN/experiments/mr_id.pt'))
data = data.to(device)
lig = LayerIntegratedGradients(text_gcn_forward_func, model.embed)

interpret_sentence(0)
Example #14
    def initialize(self, ctx):
        """In this initialize function, the BERT model is loaded and
        the Layer Integrated Gradients Algorithmfor Captum Explanations
        is initialized here.

        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artefacts parameters.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        self.device = torch.device("cuda:" +
                                   str(properties.get("gpu_id")) if torch.cuda.
                                   is_available() else "cpu")
        # read configs for the mode, model_name, etc. from setup_config.json
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        else:
            logger.warning("Missing the setup_config.json file.")

        # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode
        # further setup config can be added.
        if self.setup_config["save_mode"] == "torchscript":
            self.model = torch.jit.load(model_pt_path)
        elif self.setup_config["save_mode"] == "pretrained":
            if self.setup_config["mode"] == "sequence_classification":
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    model_dir)
            elif self.setup_config["mode"] == "question_answering":
                self.model = AutoModelForQuestionAnswering.from_pretrained(
                    model_dir)
            elif self.setup_config["mode"] == "token_classification":
                self.model = AutoModelForTokenClassification.from_pretrained(
                    model_dir)
            else:
                logger.warning("Missing the operation mode.")
        else:
            logger.warning("Missing the checkpoint or state_dict.")

        if any(fname for fname in os.listdir(model_dir)
               if fname.startswith("vocab.")
               and os.path.isfile(os.path.join(model_dir, fname))):
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_dir, do_lower_case=self.setup_config["do_lower_case"])
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.setup_config["model_name"],
                do_lower_case=self.setup_config["do_lower_case"],
            )

        self.model.to(self.device)
        self.model.eval()

        logger.info("Transformer model from path %s loaded successfully",
                    model_dir)

        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        # Question answering does not need the index_to_name.json file.
        if not self.setup_config["mode"] == "question_answering":
            if os.path.isfile(mapping_file_path):
                with open(mapping_file_path) as f:
                    self.mapping = json.load(f)
            else:
                logger.warning("Missing the index_to_name.json file.")

        # ------------------------------- Captum initialization ----------------------------#
        self.lig = LayerIntegratedGradients(captum_sequence_forward,
                                            self.model.bert.embeddings)
        self.initialized = True
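
For reference, the keys this handler reads from setup_config.json, expressed as a Python dict; the values are illustrative assumptions:

setup_config = {
    "save_mode": "pretrained",          # or "torchscript"
    "mode": "sequence_classification",  # or "question_answering" / "token_classification"
    "model_name": "bert-base-uncased",
    "do_lower_case": True,
}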
Example #15
def main(model_path, n_steps=50):
    #pylint: disable=missing-docstring, too-many-locals
    n_steps = int(n_steps)

    # load the model and tokenizer
    model = load_deprecated_model(str(model_path))
    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

    # tokenize the sentence for classification
    sequence = """Some might not find it to totally be like Pokémon without Ash.
    But this is definitely a Pokémon movie and way better than most of their animated movies.
    The CGI nailed the looks of all the Pokémon creatures and their voices.
    The movie is charming, funny and fun as well. They did a great job introducing this world to the
    big screen. I definitely want more."""

    input_ids, token_type_ids, attention_mask = prepare_input(
        sequence, tokenizer)

    # create a baseline of zeros in the same shape as the inputs
    baseline_ids = torch.zeros(input_ids.shape, dtype=torch.int64)

    # create an instance of layer intermediate gradients based upon the embedding layer
    lig = LayerIntermediateGradients(sequence_forward_func,
                                     model.bert.embeddings)
    grads, step_sizes = lig.attribute(inputs=input_ids,
                                      baselines=baseline_ids,
                                      additional_forward_args=(model,
                                                               token_type_ids,
                                                               attention_mask),
                                      n_steps=n_steps)

    print("Shape of the returned gradients: ")
    print(grads.shape)
    print("Shape of the step sizes: ")
    print(step_sizes.shape)

    # now calculate attributions from the intermediate gradients

    # multiply by the step sizes
    scaled_grads = grads.view(n_steps, -1) * step_sizes
    # reshape and sum along the num_steps dimension
    scaled_grads = torch.sum(scaled_grads.reshape((n_steps, 1) +
                                                  grads.shape[1:]),
                             dim=0)
    # pass forward the input and baseline ids for reference
    forward_input_ids = model.bert.embeddings.forward(input_ids)
    forward_baseline_ids = model.bert.embeddings.forward(baseline_ids)
    # multiply the scaled gradients by the difference of inputs and baselines to obtain attributions
    attributions = scaled_grads * (forward_input_ids - forward_baseline_ids)
    print("Attributions calculated from intermediate gradients: ")
    print(attributions.shape)
    print(attributions)

    # compare to layer integrated gradients
    layer_integrated = LayerIntegratedGradients(sequence_forward_func,
                                                model.bert.embeddings)
    attrs = layer_integrated.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(model, token_type_ids, attention_mask),
        n_steps=n_steps,
        return_convergence_delta=False)
    print("Attributions from layer integrated gradients: ")
    print(attrs.shape)
    print(attrs)
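    # up to numerical error, the manual Riemann-sum reconstruction above should
    # agree with Captum's layer integrated gradients; a quick sanity check:
    print(torch.allclose(attributions, attrs, atol=1e-4))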
Example #16
# multi-gpu evaluate
if args.n_gpu > 1:
    model = torch.nn.DataParallel(model)

# Eval!
logger.info("***** Running evaluation %s *****", prefix)
logger.info("  Num examples = %d", len(eval_dataset))
logger.info("  Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
model.eval()

explainer = LayerIntegratedGradients(predict_with_embeddings,
                                     model.bert.embeddings)
# deeplift_model = NerModel(model)
# explainer = LayerDeepLift(deeplift_model, deeplift_model.model.bert.embeddings)

example_index = 0

all_attributions = []

for batch in tqdm(eval_dataloader, desc="Evaluating"):
    example_attrs = []
    batch = tuple(t.to(args.device) for t in batch)

    input_ids = batch[0]
    attention_mask = batch[1]
    segment_ids = batch[2]
    batch_labels = batch[3]
Example #17
    def __init__(
        self,
        custom_forward: Callable,
        embeddings: nn.Module,
        tokens: list,
        input_ids: torch.Tensor,
        ref_input_ids: torch.Tensor,
        sep_id: int,
        attention_mask: torch.Tensor,
        target: Optional[Union[int, Tuple, torch.Tensor, List]] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        ref_token_type_ids: Optional[torch.Tensor] = None,
        ref_position_ids: Optional[torch.Tensor] = None,
        internal_batch_size: Optional[int] = None,
        n_steps: int = 50,
    ):
        super().__init__(custom_forward, embeddings, tokens)
        self.input_ids = input_ids
        self.ref_input_ids = ref_input_ids
        self.attention_mask = attention_mask
        self.target = target
        self.token_type_ids = token_type_ids
        self.position_ids = position_ids
        self.ref_token_type_ids = ref_token_type_ids
        self.ref_position_ids = ref_position_ids
        self.internal_batch_size = internal_batch_size
        self.n_steps = n_steps

        self.lig = LayerIntegratedGradients(self.custom_forward,
                                            self.embeddings)

        if self.token_type_ids is not None and self.position_ids is not None:
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.token_type_ids,
                        self.position_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_token_type_ids,
                    self.ref_position_ids,
                ),
                target=self.target,
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask,),
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )
        elif self.position_ids is not None:
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.position_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_position_ids,
                ),
                target=self.target,
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask,),
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )
        elif self.token_type_ids is not None:
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.token_type_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_token_type_ids,
                ),
                target=self.target,
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask,),
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )

        else:
            self._attributions, self.delta = self.lig.attribute(
                inputs=self.input_ids,
                baselines=self.ref_input_ids,
                target=self.target,
                return_convergence_delta=True,
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )
Example #18
    
    print('Epoch: {}, train loss: {}, val loss: {}, train acc: {}, val acc: {}, train f1: {}, val f1: {}'.format(epoch, train_loss, val_loss, train_acc, val_acc, train_f1, val_f1))
model.load_state_dict(best_model)


metrics = test_evaluating(model, test_iter, criterion)
metrics["test_f1"]

!pip install -q captum

from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

PAD_IND = TEXT.vocab.stoi['<pad>']

token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
lig = LayerIntegratedGradients(model, model.embedding)

def forward_with_softmax(inp):
    logits = model(inp)
    return torch.softmax(logits, 0)[0][1]

def forward_with_sigmoid(input):
    return torch.sigmoid(model(input))


# accumulate a couple of samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 7, label = 0):
    model.eval()
    text = [tok for tok in TEXT.tokenize(sentence)]
Example #19
def run_models(model_name, model, tokenizer, sequence, device, baseline):
    """
    Run Integrated and Intermediate gradients on the model layer.

    Parameters
    ----------
    model_name: str
       Name of the model that is being run.
       Currently supported are "bert" and "xlnet".
    model: torch.nn.Module
       Module to run.
    tokenizer: transformers.tokenizer
       Tokenizer to process the sequence and produce the input ids
    sequence: str
       Sequence to get the gradients from.
    device: torch.device
       Device that models are stored on.
    baseline: str
       Baseline to run with integrated gradients. Currently supported are 'zero', 'pad', 'unk',
       'rand-norm', 'rand-unif', and 'period'.

    Returns
    -------
    grads_dict: dict
        Dictionary containing the gradient tensors with the following keys:
        "intermediate_grads", "step_sizes", "intermediates", and "integrated_grads".
    """
    features = prepare_input(sequence, tokenizer)
    input_ids = features["input_ids"].to(device)
    token_type_ids = features["token_type_ids"].to(device)
    attention_mask = features["attention_mask"].to(device)

    # set up gradients and the baseline ids
    if model_name == "bert":
        layer_interm = LayerIntermediateGradients(bert_sequence_forward_func, model.bert.embeddings)
        lig = LayerIntegratedGradients(bert_sequence_forward_func, model.bert.embeddings)
        baseline_ids = generate_bert_baselines(baseline, input_ids, tokenizer).to(device)
    elif model_name == "xlnet":
        layer_interm = LayerIntermediateGradients(
            xlnet_sequence_forward_func, model.transformer.batch_first
        )
        lig = LayerIntegratedGradients(xlnet_sequence_forward_func, model.transformer.batch_first)
        baseline_ids = generate_xlnet_baselines(baseline, input_ids, tokenizer).to(device)

    grads, step_sizes, intermediates = layer_interm.attribute(inputs=input_ids,
                                                              baselines=baseline_ids,
                                                              additional_forward_args=(
                                                                  model,
                                                                  token_type_ids,
                                                                  attention_mask
                                                              ),
                                                              target=1,
                                                              n_steps=50) # maybe pass n_steps as CLI argument

    integrated_grads = lig.attribute(inputs=input_ids,
                                     baselines=baseline_ids,
                                     additional_forward_args=(
                                         model,
                                         token_type_ids,
                                         attention_mask
                                     ),
                                     target=1,
                                     n_steps=50)

    grads_dict = {"intermediate_grads": grads.to("cpu"),
                  "step_sizes": step_sizes.to("cpu"),
                  "intermediates": intermediates.to("cpu"),
                  "integrated_grads": integrated_grads.to("cpu")}

    return grads_dict
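
A hypothetical invocation (checkpoint name and input sequence are assumptions):

import torch
from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

grads_dict = run_models("bert", model, tokenizer,
                        "A charming, funny film.", device, baseline="pad")
print(grads_dict["integrated_grads"].shape)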
                                                     
Example #20
def captum_subseq(model,
                  data_loader,
                  data_iterator,
                  metrics,
                  params,
                  num_steps,
                  before_after=2,
                  top_std=2,
                  mlp=False):
    """top_std : top 2:25 % sequences (weights > mean + top_std*std)
    """
    mod = "Base" if not attention_model else "Attn"
    mod = net_name + "-" + mod
    fname = f'above_top{top_std}std_subseqs_testData_beforeAfter{before_after}_{mod}Model_{seqmethod}Method.csv'
    print(fname)

    model.eval()

    vis_data_records = []  # passed in a reference

    try:
        if seqmethod == 'intgrad':
            layer_ig = LayerIntegratedGradients(model, model.embedding)
            interpret_sequence_copy(model,
                                    data_loader,
                                    data_iterator,
                                    layer_ig,
                                    vis_data_records,
                                    num_steps,
                                    verbose=False,
                                    mlp=mlp)
        else:
            layer_sal = Saliency(model)
            interpret_sequence_copy(model,
                                    data_loader,
                                    data_iterator,
                                    layer_sal,
                                    vis_data_records,
                                    num_steps,
                                    verbose=True)
    except Exception as e:
        print(e)

    outseq = pd.DataFrame()
    before_after = 2
    print("Extracting subsequences")
    for i, vd in enumerate(vis_data_records):
        sequence = vd.raw_input
        weights = vd.word_attributions
        predicted_label = vd.pred_class
        true_label = vd.true_class
        sampleIdx = i
        out = _subseq(sequence, weights, top_std, predicted_label, true_label,
                      before_after, sampleIdx)
        # print(out)
        outseq = outseq.append(out)
    seqdir = 'subsequences'
    os.makedirs(seqdir, exist_ok=True)

    fname = os.path.join(seqdir, fname)
    outseq.to_csv(fname, index=False)
    print("Saved to", fname)
Example #21
    def get_insights(self, input_batch, text, target):
        """This function initialize and calls the layer integrated gradient to get word importance
        of the input text if captum explanation has been selected through setup_config
        Args:
            input_batch (int): Batches of tokens IDs of text
            text (str): The Text specified in the input request
            target (int): The Target can be set to any acceptable label under the user's discretion.
        Returns:
            (list): Returns a list of importances and words.
        """

        if self.setup_config["captum_explanation"]:
            embedding_layer = getattr(self.model,
                                      self.setup_config["embedding_name"])
            embeddings = embedding_layer.embeddings
            self.lig = LayerIntegratedGradients(captum_sequence_forward,
                                                embeddings)
        else:
            logger.warning(
                "Captum Explanation is not chosen and will not be available")

        if isinstance(text, (bytes, bytearray)):
            text = text.decode('utf-8')
        text_target = ast.literal_eval(text)

        if not self.setup_config["mode"] == "question_answering":
            text = text_target["text"]
        self.target = text_target["target"]

        input_ids, ref_input_ids, attention_mask = construct_input_ref(
            text, self.tokenizer, self.device, self.setup_config["mode"])
        all_tokens = get_word_token(input_ids, self.tokenizer)
        response = {}
        response["words"] = all_tokens
        if self.setup_config["mode"] in ("sequence_classification",
                                         "token_classification"):

            attributions, delta = self.lig.attribute(
                inputs=input_ids,
                baselines=ref_input_ids,
                target=self.target,
                additional_forward_args=(attention_mask, 0, self.model),
                return_convergence_delta=True,
            )

            attributions_sum = summarize_attributions(attributions)
            response["importances"] = attributions_sum.tolist()
            response["delta"] = delta[0].tolist()

        elif self.setup_config["mode"] == "question_answering":
            attributions_start, delta_start = self.lig.attribute(
                inputs=input_ids,
                baselines=ref_input_ids,
                target=self.target,
                additional_forward_args=(attention_mask, 0, self.model),
                return_convergence_delta=True,
            )
            attributions_end, delta_end = self.lig.attribute(
                inputs=input_ids,
                baselines=ref_input_ids,
                target=self.target,
                additional_forward_args=(attention_mask, 1, self.model),
                return_convergence_delta=True,
            )
            attributions_sum_start = summarize_attributions(attributions_start)
            attributions_sum_end = summarize_attributions(attributions_end)
            response[
                "importances_answer_start"] = attributions_sum_start.tolist()
            response["importances_answer_end"] = attributions_sum_end.tolist()
            response["delta_start"] = delta_start[0].tolist()
            response["delta_end"] = delta_end[0].tolist()

        return [response]
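
captum_sequence_forward is defined elsewhere in this handler; a sketch consistent with how it is called above, where the third forward argument selects which output tensor to attribute:

def captum_sequence_forward(inputs, attention_mask=None, position=0, model=None):
    # assumed behavior: run the model and pick the logits tensor to attribute
    # (position distinguishes start vs. end scores for question answering)
    pred = model(inputs, attention_mask=attention_mask)
    return pred[position]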
Example #22
def captum_text_interpreter(text,
                            model,
                            bpetokenizer,
                            idx2label,
                            max_len=80,
                            tokenizer=None,
                            multiclass=False):
    if isinstance(text, list):
        text = " ".join(text)

    d = data_utils.process_data_for_transformers(text, bpetokenizer, tokenizer,
                                                 0)
    d = {
        "ids": torch.tensor([d['ids']], dtype=torch.long),
        "mask": torch.tensor([d['mask']], dtype=torch.long),
        "token_type_ids": torch.tensor([d['token_type_ids']], dtype=torch.long)
    }

    try:
        orig_tokens = [0] + bpetokenizer.encode(text).ids + [2]
        orig_tokens = [bpetokenizer.id_to_token(j) for j in orig_tokens]
    except Exception:
        orig_tokens = tokenizer.tokenize(text, add_special_tokens=True)

    model.eval()
    if multiclass:
        preds_proba = torch.sigmoid(
            model(d["ids"], d["mask"],
                  d["token_type_ids"])).detach().cpu().numpy()
        preds = preds_proba.argmax(-1)
        preds_proba = preds_proba[0][preds[0]]
        predicted_class = idx2label[preds[0]]
    else:
        preds_proba = torch.sigmoid(
            model(d["ids"], d["mask"],
                  d["token_type_ids"])).detach().cpu().numpy()
        preds = np.round(preds_proba)
        preds_proba = preds_proba[0][0]
        predicted_class = idx2label[preds[0][0]]

    lig = LayerIntegratedGradients(model, model.base_model.embeddings)

    reference_indices = [0] + [1] * (d["ids"].shape[1] - 2) + [2]
    reference_indices = torch.tensor([reference_indices], dtype=torch.long)
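    # ids 0 and 2 above look like BOS/EOS specials and 1 like the pad id
    # (a RoBERTa-style vocabulary is assumed); the baseline keeps the
    # specials and pads everything in between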

    attributions_ig, delta = lig.attribute(inputs=d["ids"],baselines=reference_indices,additional_forward_args=(d["mask"],d["token_type_ids"]), \
                                           return_convergence_delta=True)

    attributions = attributions_ig.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().cpu().numpy()

    visualization.visualize_text([
        visualization.VisualizationDataRecord(word_attributions=attributions,
                                              pred_prob=preds_proba,
                                              pred_class=predicted_class,
                                              true_class=predicted_class,
                                              attr_class=predicted_class,
                                              attr_score=attributions.sum(),
                                              raw_input=orig_tokens,
                                              convergence_score=delta)
    ])
Example #23
def interpret_main(text, label):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = load_model(
        '/Users/andrewmendez1/Documents/ai-ml-challenge-2020/data/Finetune BERT oversampling 8_16_2020/Model_1_4_0/model.pt',
        device)

    def predict(inputs):
        #print('model(inputs): ', model(inputs))
        return model.encoder(inputs)[0]

    def custom_forward(inputs):
        preds = predict(inputs)
        return torch.softmax(preds, dim=1)[:, 0]

    # load tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    ref_token_id = tokenizer.pad_token_id  # A token used for generating token reference
    sep_token_id = tokenizer.sep_token_id  # A token used as a separator between question and text and it is also added to the end of the text.
    cls_token_id = tokenizer.cls_token_id  # A token used for prepending to the concatenated question-text word sequence
    hook = model.encoder.bert.embeddings.register_forward_hook(save_act)
    hook.remove()

    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
        text, tokenizer, device, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(
        input_ids, device, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(
        input_ids, device)
    attention_mask = construct_attention_mask(input_ids)

    # text = "the exclusion of implied warranties is not permitted by some the above exclusion may not apply to"# label 0

    lig = LayerIntegratedGradients(custom_forward,
                                   model.encoder.bert.embeddings)
    # attributions_main, delta_main = lig.attribute(inputs=input_ids,baselines=ref_input_ids,return_convergence_delta=True,n_steps=30)
    t0 = time()
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        # n_steps=7000,
        # internal_batch_size=5,
        return_convergence_delta=True,
        n_steps=300)
    st.write("Time to complete interpretation: {} seconds".format(time() - t0))
    # print("Time in {} minutes".format( (time()-t0)/60 ))
    attributions_sum = summarize_attributions(attributions)

    all_tokens = tokenizer.convert_ids_to_tokens(
        input_ids[0].detach().tolist())
    top_tokens, values, indices = get_topk_attributed_tokens(attributions_sum,
                                                             all_tokens,
                                                             k=7)
    st.subheader("Top Tokens that the Model decided Unacceptability")
    import numpy as np
    plt.figure(figsize=(12, 6))
    x_pos = np.arange(len(values))
    plt.bar(x_pos, values.detach().numpy(), align='center')
    plt.xticks(x_pos, top_tokens, wrap=True)
    plt.xlabel("Tokens")
    plt.title(
        "Top 7 tokens that made the model classify the clause as unacceptable")
    st.pyplot()

    st.subheader(
        "Detailed Table showing Attribution Score to each word in clause")
    st.write(" ")
    st.write(
        "Positive attributions mean that the words/tokens contributed \"positively\" to the model's prediction."
    )
    st.write(
        "Negative attributions mean that the words/tokens contributed \"negatively\" to the model's prediction."
    )

    # res = ['{}({}) {:.3f}'.format(token, str(i),attributions_sum[i]) for i, token in enumerate(all_tokens)]
    df = pd.DataFrame({
        'Words': all_tokens,
        'Attributions': attributions_sum.detach().numpy()
    })
    st.table(df)
    score = predict(input_ids)
    score_vis = viz.VisualizationDataRecord(
        attributions_sum,
        torch.softmax(score, dim=1)[0][0],
        torch.argmax(torch.softmax(score, dim=1)[0]), label, text,
        attributions_sum.sum(), all_tokens, delta)
    print('\033[1m', 'Visualization For Score', '\033[0m')
    # from IPython.display import display, HTML, Image
    # viz.visualize_text([score_vis])
    # st.write(display(Image(viz.visualize_text([score_vis])) ) )

    # open('output.png', 'wb').write(im.data)
    # st.pyplot()


# text= "this license shall be effective until company in its sole and absolute at any time and for any or no disable the or suspend or terminate this license and the rights afforded to you with or without prior notice or other action by upon the termination of this you shall cease all use of the app and uninstall the company will not be liable to you or any third party for or damages of any sort as a result of terminating this license in accordance with its and termination of this license will be without prejudice to any other right or remedy company may now or in the these obligations survive termination of this"
# # label=1
# label = "?"
# main(text,label)
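
construct_input_ref_pair and the other construct_* helpers follow Captum's BERT tutorials; a minimal sketch of the pair constructor, consistent with the call above (treat the details as assumptions):

def construct_input_ref_pair(text, tokenizer, device, ref_token_id,
                             sep_token_id, cls_token_id):
    # encode the text, then build a same-length baseline of
    # [CLS] <pad> ... <pad> [SEP]
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    return (torch.tensor([input_ids], device=device),
            torch.tensor([ref_input_ids], device=device),
            len(text_ids))  # position info consumed by the token-type helper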
Example #24
def run_models(model, model_name, num_trials, subset, tokenized_list, device):
    if model_name == "bert":
        layer_interm = LayerIntermediateGradients(bert_sequence_forward_func,
                                                  model.bert.embeddings)
        lig = LayerIntegratedGradients(bert_sequence_forward_func,
                                       model.bert.embeddings)
    elif model_name == "xlnet":
        layer_interm = LayerIntermediateGradients(
            xlnet_sequence_forward_func, model.transformer.batch_first)
        lig = LayerIntegratedGradients(xlnet_sequence_forward_func,
                                       model.transformer.batch_first)

    run_through_example = tokenized_list[-1]
    tokenized_list = tokenized_list[:subset]

    input_ids = run_through_example["input_ids"].to(device)
    token_type_ids = run_through_example["token_type_ids"].to(device)
    attention_mask = run_through_example["attention_mask"].to(device)
    baseline_ids = run_through_example["baseline_ids"].to(device)

    grads, step_sizes, intermediates = layer_interm.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(model, token_type_ids, attention_mask),
        target=1,
        n_steps=50)  # maybe pass n_steps as CLI argument

    integrated_grads = lig.attribute(inputs=input_ids,
                                     baselines=baseline_ids,
                                     additional_forward_args=(model,
                                                              token_type_ids,
                                                              attention_mask),
                                     target=1,
                                     n_steps=50)

    for repetition in tqdm(range(num_trials)):
        start_time = time.perf_counter()
        for feature in tokenized_list:
            input_ids = feature["input_ids"].to(device)
            token_type_ids = feature["token_type_ids"].to(device)
            attention_mask = feature["attention_mask"].to(device)
            baseline_ids = feature["baseline_ids"].to(device)

            grads, step_sizes, intermediates = layer_interm.attribute(
                inputs=input_ids,
                baselines=baseline_ids,
                additional_forward_args=(model, token_type_ids,
                                         attention_mask),
                target=1,
                n_steps=50)  # maybe pass n_steps as CLI argument

            integrated_grads = lig.attribute(
                inputs=input_ids,
                baselines=baseline_ids,
                additional_forward_args=(model, token_type_ids,
                                         attention_mask),
                target=1,
                n_steps=50)
        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        print(
            "Repetition %s Elapsed Time for %s examples: " %
            (repetition, subset), elapsed_time)
Example #25
def captum_interactive(request):
    if request.method == 'POST':
        STORED_POSTS = request.session.get("TextAttackResult")
        form = CustomData(request.POST)
        if form.is_valid():
            input_text, model_name, recipe_name = form.cleaned_data[
                'input_text'], form.cleaned_data[
                    'model_name'], form.cleaned_data['recipe_name']
            found = False
            if STORED_POSTS:
                JSON_STORED_POSTS = json.loads(STORED_POSTS)
                for idx, el in enumerate(JSON_STORED_POSTS):
                    if el["type"] == "captum" and el[
                            "input_string"] == input_text:
                        tmp = JSON_STORED_POSTS.pop(idx)
                        JSON_STORED_POSTS.insert(0, tmp)
                        found = True
                        break

                if found:
                    request.session["TextAttackResult"] = json.dumps(
                        JSON_STORED_POSTS[:10])
                    return HttpResponseRedirect(reverse('webdemo:index'))

            original_model = transformers.AutoModelForSequenceClassification.from_pretrained(
                "textattack/" + model_name)
            original_tokenizer = textattack.models.tokenizers.AutoTokenizer(
                "textattack/" + model_name)
            model = textattack.models.wrappers.HuggingFaceModelWrapper(
                original_model, original_tokenizer)

            device = torch.device(
                "cuda:2" if torch.cuda.is_available() else "cpu")
            clone = deepcopy(model)
            clone.model.to(device)

            def calculate(input_ids, token_type_ids, attention_mask):
                return clone.model(input_ids, token_type_ids,
                                   attention_mask)[0]

            attack = textattack.commands.attack.attack_args_helpers.parse_attack_from_args(
                Args(model_name, recipe_name))
            attacked_text = textattack.shared.attacked_text.AttackedText(
                input_text)
            attack.goal_function.init_attack_example(attacked_text, 1)
            goal_func_result, _ = attack.goal_function.get_result(
                attacked_text)

            result = next(
                attack.attack_dataset([(input_text, goal_func_result.output)]))
            result_parsed = result.str_lines()
            if len(result_parsed) < 3:
                return HttpResponseNotFound('Failed')
            output_text = result_parsed[2]

            attacked_text_out = textattack.shared.attacked_text.AttackedText(
                output_text)

            orig = result.original_text()
            pert = result.perturbed_text()

            encoded = model.tokenizer.batch_encode([orig])
            batch_encoded = captum_form(encoded, device)
            x = calculate(**batch_encoded)

            pert_encoded = model.tokenizer.batch_encode([pert])
            pert_batch_encoded = captum_form(pert_encoded, device)
            x_pert = calculate(**pert_batch_encoded)

            lig = LayerIntegratedGradients(calculate,
                                           clone.model.bert.embeddings)
            attributions, delta = lig.attribute(
                inputs=batch_encoded['input_ids'],
                additional_forward_args=(batch_encoded['token_type_ids'],
                                         batch_encoded['attention_mask']),
                n_steps=10,
                target=torch.argmax(calculate(**batch_encoded)).item(),
                return_convergence_delta=True)

            attributions_pert, delta_pert = lig.attribute(
                inputs=pert_batch_encoded['input_ids'],
                additional_forward_args=(pert_batch_encoded['token_type_ids'],
                                         pert_batch_encoded['attention_mask']),
                n_steps=10,
                target=torch.argmax(calculate(**pert_batch_encoded)).item(),
                return_convergence_delta=True)

            orig = original_tokenizer.tokenizer.tokenize(orig)
            pert = original_tokenizer.tokenizer.tokenize(pert)

            atts = attributions.sum(dim=-1).squeeze(0)
            atts = atts / torch.norm(atts)

            atts_pert = attributions_pert.sum(dim=-1).squeeze(0)
            atts_pert = atts_pert / torch.norm(atts_pert)

            all_tokens = original_tokenizer.tokenizer.convert_ids_to_tokens(
                batch_encoded['input_ids'][0])
            all_tokens_pert = original_tokenizer.tokenizer.convert_ids_to_tokens(
                pert_batch_encoded['input_ids'][0])

            v = viz.VisualizationDataRecord(atts[:45].detach().cpu(),
                                            torch.max(x).item(),
                                            torch.argmax(x, dim=1).item(),
                                            goal_func_result.output, 2,
                                            atts.sum().detach(),
                                            all_tokens[:45], delta)

            v_pert = viz.VisualizationDataRecord(
                atts_pert[:45].detach().cpu(),
                torch.max(x_pert).item(),
                torch.argmax(x_pert, dim=1).item(), goal_func_result.output, 2,
                atts_pert.sum().detach(), all_tokens_pert[:45], delta_pert)

            formattedHTML = formatDisplay([v, v_pert])

            post = {
                "type": "captum",
                "input_string": input_text,
                "model_name": model_name,
                "recipe_name": recipe_name,
                "output_string": output_text,
                "html_input_string": formattedHTML[0],
                "html_output_string": formattedHTML[1],
            }

            if STORED_POSTS:
                JSON_STORED_POSTS = json.loads(STORED_POSTS)
                JSON_STORED_POSTS.insert(0, post)
                request.session["TextAttackResult"] = json.dumps(
                    JSON_STORED_POSTS[:10])
            else:
                request.session["TextAttackResult"] = json.dumps([post])

            return HttpResponseRedirect(reverse('webdemo:index'))

        else:
            return HttpResponseNotFound('Failed')

        return HttpResponse('Success')

    return HttpResponseNotFound('<h1>Not Found</h1>')
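
captum_form is not defined in this view; a plausible sketch given how it is used (turning the tokenizer's batch encoding into device tensors that match calculate's keyword arguments):

def captum_form(encoded, device):
    # assumed behavior: convert each encoded field into a batched LongTensor
    # on the target device, keyed to match calculate(**batch_encoded)
    return {key: torch.tensor(value, device=device)
            for key, value in encoded.items()}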
Example #26
def main(model_path, n_steps=50):
    #pylint: disable=missing-docstring, too-many-locals

    # disable warning messages for initial pretrained XLNet module.
    logging.basicConfig(level=logging.ERROR)
    n_steps = int(n_steps)

    # load the model and tokenizer
    model = load_model(str(model_path), device=torch.device("cpu"))
    tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')

    # tokenize the sentence for classification
    sequence = """Some might not find it to totally be like Pokémon without Ash.
    But this is definitely a Pokémon movie and way better than most of their animated movies.
    The CGI nailed the looks of all the Pokémon creatures and their voices.
    The movie is charming, funny and fun as well. They did a great job introducing this world to the
    big screen. I definitely want more."""

    features = prepare_input(sequence, tokenizer)
    input_ids = features["input_ids"]
    token_type_ids = features["token_type_ids"]
    attention_mask = features["attention_mask"]

    # create a baseline of zeros in the same shape as the inputs
    baseline_ids = torch.zeros(input_ids.shape, dtype=torch.int64)

    # instance of layer intermediate gradients based upon the dummy layer representing the embeddings
    lig = LayerIntermediateGradients(sequence_forward_func,
                                     model.transformer.batch_first)
    grads, step_sizes, intermediates = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(model, token_type_ids, attention_mask),
        target=1,
        n_steps=n_steps)

    print("Shape of the returned gradients: ")
    print(grads.shape)
    print("Shape of the step sizes: ")
    print(step_sizes.shape)

    # now calculate attributions from the intermediate gradients

    # multiply by the step sizes
    scaled_grads = grads.view(n_steps, -1) * step_sizes
    # reshape and sum along the num_steps dimension
    scaled_grads = torch.sum(scaled_grads.reshape((n_steps, 1) +
                                                  grads.shape[1:]),
                             dim=0)
    # pass forward the input and baseline ids for reference
    forward_input_ids = model.transformer.word_embedding.forward(input_ids)
    forward_baseline_ids = model.transformer.word_embedding.forward(
        baseline_ids)
    # multiply the scaled gradients by the difference of inputs and baselines to obtain attributions
    attributions = scaled_grads * (forward_input_ids - forward_baseline_ids)
    print("Attributions calculated from intermediate gradients: ")
    print(attributions.shape)
    print(attributions)

    # compare to layer integrated gradients
    layer_integrated = LayerIntegratedGradients(sequence_forward_func,
                                                model.transformer.batch_first)
    attrs = layer_integrated.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(model, token_type_ids, attention_mask),
        n_steps=n_steps,
        target=1,
        return_convergence_delta=False)
    print("Attributions from layer integrated gradients: ")
    print(attrs.shape)
    print(attrs)

    print("Intermediate tensor shape: ", intermediates.shape)
    print("Intermediate tensor: ", intermediates)
Example #27
    return position_ids, ref_position_ids


def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)


def custom_forward(inputs):
    preds = predict(inputs)
    # [:, 1] selects the class-1 column; use [:, 0] instead to attribute
    # the other class
    return torch.softmax(preds, dim=1)[:, 1]


lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)


def get_attribution_for_test_set(lig, test_data_set):
    words_ls = []
    attributions_ls = []
    test_set_word_att_dict = {}

    for index, row in test_data_set.iterrows():
        text = row["Text"]
        clean_text = row["reduced_text_clean"]
        oh_label = row['oh_label']

        input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
            clean_text, ref_token_id, sep_token_id, cls_token_id)
        token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(
Example #28
def compute_and_output_attributions(outcome='top_level'):

		import pickle

		print ('Loading data ...')
		
		if outcome == 'top_level':
			prepared_data_file = PREPARED_DATA_FILE_top_level
		elif outcome == 'mn_avg_eb':
			prepared_data_file = PREPARED_DATA_FILE_mn_avg_eb
		elif outcome == 'mn_avg_eb_adv':
			prepared_data_file = PREPARED_DATA_FILE_mn_avg_eb_adv
		elif outcome == 'perwht':
			prepared_data_file = PREPARED_DATA_FILE_perwht
		elif outcome == 'perfrl':
			prepared_data_file = PREPARED_DATA_FILE_perfrl
		else:
			prepared_data_file = PREPARED_DATA_FILE_mn_grd_eb

		df = pd.read_csv(RAW_DATA_FILE)
		with open(prepared_data_file, 'rb') as f:
			all_input_ids, labels_target, attention_masks, sentences_per_school, url, perwht, perfrl, share_singleparent, totenrl, share_collegeplus, mail_returnrate = pickle.load(f, encoding='latin1')

		device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
		print ('Loading model ...')
		model, BEST_MODEL_DIR = get_best_model(outcome)

		model.to(device)
		model.zero_grad()

		# load tokenizer
		tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

		# Define wrapper function for integrated gradients
		def bert_forward_wrapper(input_ids, num_sentences, attention_mask=None, position=0):
				return model(input_ids, num_sentences, attention_mask=attention_mask)

		from captum.attr import TokenReferenceBase
		from captum.attr import IntegratedGradients, LayerIntegratedGradients
		from captum.attr import visualization as viz

		# We only want to compute IG over schools in our validation set
		data_splits = ['validation']
		all_summarized_attr = []
		input_ids_for_attr = []
		count = 0

		internal_batch_size = 12
		n_steps = 48

		OUTPUT_DIR = '{}interp/attributions/{}/'
		OUTPUT_FILE = OUTPUT_DIR + '{}_{}_loss_{}.json'
		if not os.path.exists(OUTPUT_DIR.format(BASE_DIR, BEST_MODEL_DIR)):
			os.makedirs(OUTPUT_DIR.format(BASE_DIR, BEST_MODEL_DIR))

		start_ind = len([int(f.split('_')[0]) for f in os.listdir(OUTPUT_DIR.format(BASE_DIR, BEST_MODEL_DIR))])

		for d in data_splits:

			# Standardize our outcome measure, like we did for training and validation
			outcome_key = outcome.split('_adv')[0]
			labels_target[d] = torch.FloatTensor((labels_target[d] - np.mean(df[outcome_key])) / np.std(df[outcome_key]))
			        
			n_schools = torch.LongTensor(all_input_ids[d]).size(0)
			print ("num schools {} for {} split".format(n_schools, d))
			
			for i in range(start_ind, n_schools):
					
				print (d, i)
				count += 1
				
				# Prepare data
				input_ids = torch.LongTensor([all_input_ids[d][i]]).squeeze(0).to(device)
				num_sentences = int(sentences_per_school[d][i])
				label_t = labels_target[d][i].unsqueeze(0).to(device)
				input_mask = torch.tensor([attention_masks[d][i]]).squeeze(0).to(device)
				label_perfrl = torch.tensor([perfrl[d][i]]).to(device)
				label_perwht = torch.tensor([perwht[d][i]]).to(device)
				label_share_singleparent = torch.tensor([share_singleparent[d][i]]).to(device)
				label_totenrl = torch.tensor([totenrl[d][i]]).to(device)
				label_share_collegeplus = torch.tensor([share_collegeplus[d][i]]).to(device)
				label_mail_returnrate = torch.tensor([mail_returnrate[d][i]]).to(device)

				# Get the prediction for this example
				pred = model(input_ids, num_sentences, attention_mask=input_mask)								
				mse = F.mse_loss(pred[0].unsqueeze_(0), label_t)

				# Generate reference sequence for integrated gradients
				ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
				token_reference = TokenReferenceBase(reference_token_idx=ref_token_id)
				ref_input_ids = token_reference.generate_reference(input_ids.size(0), device=device).unsqueeze(1).repeat(1, input_ids.size(1)).long()

				# Compute integrated gradients
				lig = LayerIntegratedGradients(bert_forward_wrapper, model.bert.embeddings)
				attributions, conv_delta = lig.attribute(
					inputs=input_ids, 
					baselines=ref_input_ids,
					additional_forward_args=(num_sentences, input_mask, 0), 
					internal_batch_size=internal_batch_size,
					n_steps=n_steps,
					return_convergence_delta=True)

				# Sum attributions for each hidden dimension describing a token
				summarized_attr = attributions.sum(dim=-1).squeeze(0)
				n_sent = summarized_attr.size(0)
				attr_for_school_sents = defaultdict(dict)

				# Iterate over sentences and store the attributions per token in each sentence
				for j in range(0, n_sent):
					indices = input_ids[j].detach().squeeze(0).tolist()
					all_tokens = tokenizer.convert_ids_to_tokens(indices)
					attr_for_school_sents[j]['tokens'] = all_tokens
					attr_for_school_sents[j]['attributions'] = summarized_attr[j].tolist()
					assert (len(attr_for_school_sents[j]['tokens']) == len(attr_for_school_sents[j]['attributions']))
				with open(OUTPUT_FILE.format(BASE_DIR, BEST_MODEL_DIR, i, d, mse), 'w') as f:
					f.write(json.dumps(attr_for_school_sents, indent=4))