Example #1
    def _register_hooks(self, alpha: float, embeddings_list: List, token_offsets: List):
        """
        Register a forward hook on the embedding layer which scales the embeddings by alpha. Used
        for one term in the Integrated Gradients sum.

        We store the embedding output into the embeddings_list when alpha is zero.  This is used
        later to element-wise multiply the input by the averaged gradients.
        """

        def forward_hook(module, inputs, output):
            # Save the unscaled embedding output for later use. Only do so on
            # the first call (alpha == 0).
            if alpha == 0:
                embeddings_list.append(output.squeeze(0).clone().detach())

            # Scale the embedding by alpha
            output.mul_(alpha)

        def get_token_offsets(module, inputs, outputs):
            offsets = util.get_token_offsets_from_text_field_inputs(inputs)
            if offsets is not None:
                token_offsets.append(offsets)

        # Register the hooks
        handles = []
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handles.append(embedding_layer.register_forward_hook(forward_hook))
        text_field_embedder = util.find_text_field_embedder(self.predictor._model)
        handles.append(text_field_embedder.register_forward_hook(get_token_offsets))
        return handles
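
The caller is expected to register this hook once per interpolation step and remove it after each forward pass. Below is a minimal, self-contained sketch (not AllenNLP's actual integration loop) of the same forward-hook scaling trick on a plain torch.nn.Embedding; the toy module, token ids, and alpha values are invented for illustration.

import torch

embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
token_ids = torch.tensor([[1, 2, 3]])
embeddings_list = []

def make_scaling_hook(alpha: float):
    def forward_hook(module, inputs, output):
        if alpha == 0:
            # Keep an unscaled copy; Integrated Gradients multiplies the
            # averaged gradients by these original embedding values at the end.
            embeddings_list.append(output.squeeze(0).clone().detach())
        # In-place scaling: everything downstream of the embedding layer now
        # sees alpha * embedding, i.e. one point on the interpolation path.
        output.mul_(alpha)
    return forward_hook

for alpha in (0.0, 0.5, 1.0):
    handle = embedding.register_forward_hook(make_scaling_hook(alpha))
    scaled = embedding(token_ids)
    handle.remove()  # remove the hook so the next iteration registers a fresh one
    print(alpha, scaled.abs().sum().item())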
Example #2
    def get_interpretable_text_field_embedder(self) -> torch.nn.Module:
        """
        Returns the first `TextFieldEmbedder` of the model.
        If the predictor wraps around a non-AllenNLP model,
        this function should be overridden to specify the correct embedder.
        """
        try:
            return util.find_text_field_embedder(self._model)
        except RuntimeError:
            raise RuntimeError(
                "If the model does not use `TextFieldEmbedder`, please override "
                "`get_interpretable_text_field_embedder` in your predictor to specify "
                "the embedding layer."
            )
Example #3
    def _register_embedding_gradient_hooks(self, embedding_gradients):
        """
        Registers a backward hook on the embedding layer of the model.  Used to save the gradients
        of the embeddings for use in `get_gradients()`.

        When there are multiple inputs (e.g., a passage and question), the hook
        will be called multiple times. We append all the embeddings gradients
        to a list.

        We additionally add a hook on the _forward_ pass of the model's `TextFieldEmbedder` to save
        token offsets, if there are any.  Having token offsets means that you're using a mismatched
        token indexer, so we need to aggregate the gradients across the wordpieces in a token.  We
        do that by averaging the wordpiece gradients within each token.
        """

        def hook_layers(module, grad_in, grad_out):
            grads = grad_out[0]
            if self._token_offsets:
                # If you have a mismatched indexer with multiple TextFields, it's quite possible
                # that the order we deal with the gradients is wrong.  We'll just take items from
                # the list one at a time, and try to aggregate the gradients.  If we got the order
                # wrong, we should crash, so you'll know about it.  If you get an error because of
                # that, open an issue on github, and we'll see what we can do.  The intersection of
                # multiple TextFields and mismatched indexers is pretty small (currently empty, that
                # I know of), so we'll ignore this corner case until it's needed.
                offsets = self._token_offsets.pop(0)
                span_grads, span_mask = util.batched_span_select(grads.contiguous(), offsets)
                span_mask = span_mask.unsqueeze(-1)
                span_grads *= span_mask  # zero out paddings

                span_grads_sum = span_grads.sum(2)
                span_grads_len = span_mask.sum(2)
                # Shape: (batch_size, num_orig_tokens, embedding_size)
                grads = span_grads_sum / torch.clamp_min(span_grads_len, 1)

                # All the places where the span length is zero, write in zeros.
                grads[(span_grads_len == 0).expand(grads.shape)] = 0

            embedding_gradients.append(grads)

        def get_token_offsets(module, inputs, outputs):
            offsets = util.get_token_offsets_from_text_field_inputs(inputs)
            if offsets is not None:
                self._token_offsets.append(offsets)

        hooks = []
        text_field_embedder = util.find_text_field_embedder(self._model)
        hooks.append(text_field_embedder.register_forward_hook(get_token_offsets))
        embedding_layer = util.find_embedding_layer(self._model)
        hooks.append(embedding_layer.register_backward_hook(hook_layers))
        return hooks
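
To make the wordpiece-to-token aggregation above concrete, the following self-contained toy runs the same batched_span_select / mask / average steps on an invented gradient tensor and offsets (it only assumes AllenNLP and PyTorch are installed).

import torch
from allennlp.nn import util

# Fake "gradients" for a batch of 1 with 5 wordpieces and embedding size 3.
grads = torch.arange(15, dtype=torch.float).view(1, 5, 3)
# Inclusive (start, end) wordpiece spans: original token 0 covers wordpieces
# 0-1, original token 1 covers wordpieces 2-4.
offsets = torch.tensor([[[0, 1], [2, 4]]])

span_grads, span_mask = util.batched_span_select(grads.contiguous(), offsets)
span_mask = span_mask.unsqueeze(-1)
span_grads = span_grads * span_mask  # zero out padding within shorter spans

# Average the wordpiece gradients within each original token.
token_grads = span_grads.sum(2) / torch.clamp_min(span_mask.sum(2), 1)
print(token_grads.shape)  # torch.Size([1, 2, 3]): one gradient per original token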
Example #4
    def _register_hooks(self, embeddings_list: List, token_offsets: List):
        """
        Finds the model's embedding layer and registers a forward hook onto it. When forward()
        is called, embeddings_list is filled with the embedding values. This is necessary because
        our normalization scheme multiplies the gradient by the embedding value.

        A second forward hook on the model's `TextFieldEmbedder` records token offsets, if there
        are any, so that gradients can later be aggregated across the wordpieces in a token.
        """
        def forward_hook(module, inputs, output):
            embeddings_list.append(output.squeeze(0).clone().detach())

        def get_token_offsets(module, inputs, outputs):
            offsets = util.get_token_offsets_from_text_field_inputs(inputs)
            if offsets is not None:
                token_offsets.append(offsets)

        # Register the hooks
        handles = []
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handles.append(embedding_layer.register_forward_hook(forward_hook))
        text_field_embedder = util.find_text_field_embedder(self.predictor._model)
        handles.append(text_field_embedder.register_forward_hook(get_token_offsets))
        return handles
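
All of these _register_hooks variants rely on the caller to remove the returned handles after the forward pass; otherwise the hooks keep firing on unrelated calls. A minimal standalone sketch of that register, run forward, remove pattern, with a made-up toy model and input:

import torch

# Toy model: an embedding layer followed by a linear head.
model = torch.nn.Sequential(torch.nn.Embedding(20, 8), torch.nn.Linear(8, 2))
embeddings_list = []

def forward_hook(module, inputs, output):
    embeddings_list.append(output.squeeze(0).clone().detach())

handles = [model[0].register_forward_hook(forward_hook)]
try:
    _ = model(torch.tensor([[4, 7, 2]]))
finally:
    # Remove the handles even if forward() raises, so the hooks do not leak
    # into later, unrelated forward passes.
    for handle in handles:
        handle.remove()

print(len(embeddings_list), embeddings_list[0].shape)  # 1 torch.Size([3, 8])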