def _register_hooks(self, alpha: float, embeddings_list: List, token_offsets: List):
    """
    Register a forward hook on the embedding layer which scales the embeddings by alpha.
    Used for one term in the Integrated Gradients sum.

    We store the embedding output into the embeddings_list when alpha is zero.  This is
    used later to element-wise multiply the input by the averaged gradients.
    """

    def forward_hook(module, inputs, output):
        # Save the input for later use.  Only do so on the first call (alpha == 0).
        if alpha == 0:
            embeddings_list.append(output.squeeze(0).clone().detach())

        # Scale the embedding by alpha.
        output.mul_(alpha)

    def get_token_offsets(module, inputs, outputs):
        offsets = util.get_token_offsets_from_text_field_inputs(inputs)
        if offsets is not None:
            token_offsets.append(offsets)

    # Register the hooks.
    handles = []
    embedding_layer = util.find_embedding_layer(self.predictor._model)
    handles.append(embedding_layer.register_forward_hook(forward_hook))
    text_field_embedder = util.find_text_field_embedder(self.predictor._model)
    handles.append(text_field_embedder.register_forward_hook(get_token_offsets))

    return handles
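# Illustrative sketch (not part of the source above): how a caller might drive these hooks
# to approximate the Integrated Gradients path integral.  The method name, the ten-step
# left Riemann sum, and the `self.predictor.get_gradients([instance])` call are assumptions
# made for this example, not a verbatim copy of the surrounding class.
def _integrate_gradients_sketch(self, instance):
    import numpy

    embeddings_list: List = []
    token_offsets: List = []
    accumulated: dict = {}
    steps = 10
    # alpha == 0 is included so the forward hook also captures the unscaled embeddings.
    for alpha in numpy.linspace(0, 1.0, num=steps, endpoint=False):
        handles = self._register_hooks(alpha, embeddings_list, token_offsets)
        try:
            grads, _ = self.predictor.get_gradients([instance])
        finally:
            # Always remove the hooks so later forward passes are not rescaled.
            for handle in handles:
                handle.remove()
        for key, grad in grads.items():
            accumulated[key] = accumulated.get(key, 0.0) + grad / steps
    return accumulated, embeddings_list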
def get_interpretable_text_field_embedder(self) -> torch.nn.Module:
    """
    Returns the first `TextFieldEmbedder` of the model.

    If the predictor wraps around a non-AllenNLP model, this function should be overridden
    to specify the correct embedder.
    """
    try:
        return util.find_text_field_embedder(self._model)
    except RuntimeError:
        raise RuntimeError(
            "If the model does not use `TextFieldEmbedder`, please override "
            "`get_interpretable_text_field_embedder` in your predictor to specify "
            "the embedding layer."
        )
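# Illustrative sketch (an assumption, not library source): overriding the method above in a
# predictor that wraps a plain PyTorch model without a `TextFieldEmbedder`.  The class name
# `WrappedBertPredictor` and the `self._model.bert.embeddings` attribute path are hypothetical.
import torch
from allennlp.predictors import Predictor


class WrappedBertPredictor(Predictor):
    def get_interpretable_text_field_embedder(self) -> torch.nn.Module:
        # Return whichever module produces the embeddings that the interpreters should hook.
        return self._model.bert.embeddings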
def _register_embedding_gradient_hooks(self, embedding_gradients):
    """
    Registers a backward hook on the embedding layer of the model.  Used to save the
    gradients of the embeddings for use in get_gradients().

    When there are multiple inputs (e.g., a passage and a question), the hook will be
    called multiple times.  We append all of the embedding gradients to a list.

    We additionally add a hook on the _forward_ pass of the model's `TextFieldEmbedder`
    to save token offsets, if there are any.  Having token offsets means that you're
    using a mismatched token indexer, so we need to aggregate the gradients across the
    wordpieces in each token.  We do that by averaging the wordpiece gradients within
    each token (a masked sum divided by the span length).
    """

    def hook_layers(module, grad_in, grad_out):
        grads = grad_out[0]
        if self._token_offsets:
            # If you have a mismatched indexer with multiple TextFields, it's quite
            # possible that the order we deal with the gradients in is wrong.  We'll just
            # take items from the list one at a time, and try to aggregate the gradients.
            # If we got the order wrong, we should crash, so you'll know about it.  If you
            # get an error because of that, open an issue on github, and we'll see what we
            # can do.  The intersection of multiple TextFields and mismatched indexers is
            # pretty small (currently empty, that I know of), so we'll ignore this corner
            # case until it's needed.
            offsets = self._token_offsets.pop(0)
            span_grads, span_mask = util.batched_span_select(grads.contiguous(), offsets)
            span_mask = span_mask.unsqueeze(-1)
            span_grads *= span_mask  # zero out paddings

            span_grads_sum = span_grads.sum(2)
            span_grads_len = span_mask.sum(2)
            # Shape: (batch_size, num_orig_tokens, embedding_size)
            grads = span_grads_sum / torch.clamp_min(span_grads_len, 1)

            # All the places where the span length is zero, write in zeros.
            grads[(span_grads_len == 0).expand(grads.shape)] = 0

        embedding_gradients.append(grads)

    def get_token_offsets(module, inputs, outputs):
        offsets = util.get_token_offsets_from_text_field_inputs(inputs)
        if offsets is not None:
            self._token_offsets.append(offsets)

    hooks = []
    text_field_embedder = util.find_text_field_embedder(self._model)
    hooks.append(text_field_embedder.register_forward_hook(get_token_offsets))
    embedding_layer = util.find_embedding_layer(self._model)
    hooks.append(embedding_layer.register_backward_hook(hook_layers))
    return hooks
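# Illustrative sketch (not library source): the same wordpiece-to-token aggregation that
# `hook_layers` performs, run standalone on a toy gradient tensor so the shapes are easy to
# follow.  The toy sizes and the offsets below are made up for this example.
import torch
from allennlp.nn.util import batched_span_select

batch_size, num_wordpieces, embedding_size = 1, 6, 4
grads = torch.randn(batch_size, num_wordpieces, embedding_size)
# One (inclusive) wordpiece span per original token; the last token spans two wordpieces.
offsets = torch.tensor([[[0, 0], [1, 1], [2, 2], [3, 4]]])

span_grads, span_mask = batched_span_select(grads.contiguous(), offsets)
span_mask = span_mask.unsqueeze(-1)
span_grads = span_grads * span_mask  # zero out padding positions inside each span

span_grads_sum = span_grads.sum(2)
span_grads_len = span_mask.sum(2)
# Shape: (batch_size, num_orig_tokens, embedding_size) -- average over each token's wordpieces.
token_grads = span_grads_sum / torch.clamp_min(span_grads_len, 1)
token_grads[(span_grads_len == 0).expand(token_grads.shape)] = 0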
def _register_hooks(self, embeddings_list: List, token_offsets: List):
    """
    Registers a forward hook on the model's embedding layer so that, when forward() is
    called, embeddings_list is filled with the embedding values.  This is necessary
    because our normalization scheme multiplies the gradient by the embedding value.
    A second forward hook on the `TextFieldEmbedder` records token offsets, if any.
    """

    def forward_hook(module, inputs, output):
        embeddings_list.append(output.squeeze(0).clone().detach())

    def get_token_offsets(module, inputs, outputs):
        offsets = util.get_token_offsets_from_text_field_inputs(inputs)
        if offsets is not None:
            token_offsets.append(offsets)

    # Register the hooks.
    handles = []
    embedding_layer = util.find_embedding_layer(self.predictor._model)
    handles.append(embedding_layer.register_forward_hook(forward_hook))
    text_field_embedder = util.find_text_field_embedder(self.predictor._model)
    handles.append(text_field_embedder.register_forward_hook(get_token_offsets))

    return handles
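# Illustrative sketch (an assumption, not the library's code): one way the hooks above could
# be combined with backpropagated gradients to produce a gradient-times-embedding saliency
# score for a single instance.  The method name, the `self.predictor.get_gradients` call,
# and the reversal/normalization details are examples, not the exact interpreter source.
def _gradient_times_embedding_sketch(self, instance):
    import numpy

    embeddings_list: List = []
    token_offsets: List = []
    handles = self._register_hooks(embeddings_list, token_offsets)
    try:
        grads, _ = self.predictor.get_gradients([instance])
    finally:
        for handle in handles:
            handle.remove()

    # Hooks fire once per TextField input; reverse so the saved embeddings line up with
    # the gradient keys (an assumption made for this example).
    embeddings_list.reverse()

    scores = {}
    for idx, (key, grad) in enumerate(grads.items()):
        embedding = embeddings_list[idx].cpu().numpy()
        # Dot each token's gradient with its embedding, then L1-normalize over the tokens.
        emb_grad = numpy.sum(grad[0] * embedding, axis=1)
        norm = numpy.linalg.norm(emb_grad, ord=1)
        scores[key] = [float(abs(e) / norm) for e in emb_grad]
    return scores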