def get_averaged_grad(
    model: GPT2LMHeadModel,
    trigger_tokens: torch.Tensor,
    target_tokens: torch.Tensor,
    targets_embeddings: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the gradient of the trigger tokens wrt to the loss averaged across all the targets.
    """
    num_targets = target_tokens.shape[0]
    trigger_length = trigger_tokens.shape[0]
    targets_padding_mask = target_tokens.eq(-1)

    model_input_embeddings = model.get_input_embeddings()
    trigger_embeddings = (
        model_input_embeddings(trigger_tokens).detach().requires_grad_(True)
    )
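    # If no precomputed target embeddings were passed in, embed the targets here, first
    # replacing the `-1` padding ids with a valid dummy id (`1`), since `-1` is not a real
    # token id.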
    if targets_embeddings is None:
        target_inputs = target_tokens.clone()
        target_inputs[targets_padding_mask] = 1
        targets_embeddings = model_input_embeddings(target_inputs)
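    # Prepend the (shared) trigger embeddings to every target sequence along the sequence
    # dimension: `lm_input` has shape (num_targets, trigger_length + max_target_length, emb_dim).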
    lm_input = torch.cat(
        [
            trigger_embeddings.unsqueeze(0).expand(
                num_targets, *trigger_embeddings.shape
            ),
            targets_embeddings,
        ],
        dim=1,
    )
    model.zero_grad()
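    # Attend to every trigger position and mask out the padding positions in the targets.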
    attention_mask = torch.cat(
        [
            torch.ones(
                (num_targets, trigger_length),
                device=target_tokens.device,
                dtype=torch.bool,
            ),
            targets_padding_mask.logical_not(),
        ],
        dim=1,
    )
    lm_output = model(inputs_embeds=lm_input, attention_mask=attention_mask)
    logits = lm_output[0]
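    # The logits at position i predict the token at position i + 1, so the slice below starts at
    # the last trigger position and stops one before the final position, aligning each logit
    # with the target token it should predict.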
    target_logits = logits[:, trigger_length - 1 : -1, :].reshape(
        num_targets * target_tokens.shape[1], -1
    )
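    # The targets are padded with `-1`, hence `ignore_index=-1`. With the default mean
    # reduction, the loss (and therefore the gradient on the trigger embeddings) is averaged
    # over all non-padding target tokens across every target sequence.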
    loss = torch.nn.functional.cross_entropy(
        target_logits, target_tokens.view(-1), ignore_index=-1
    )
    loss.backward()
    embeddings_average_grad = trigger_embeddings.grad.detach()
    model.zero_grad()
    return embeddings_average_grad
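
# Usage sketch (the token ids below are arbitrary illustrations, not meaningful BPE pieces):
# given a `GPT2LMHeadModel` and tensors on the same device,
#
#     trigger_tokens = torch.tensor([10, 11, 12])                  # (trigger_length,)
#     target_tokens = torch.tensor([[20, 21, -1], [30, 31, 32]])   # (num_targets, max_len), `-1`-padded
#     grad = get_averaged_grad(model, trigger_tokens, target_tokens)
#
# `grad` has the same shape as the trigger embeddings, i.e. (trigger_length, embedding_dim).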


# NOTE: the class name below is a placeholder; the original `class` statement does not appear
# here, but `super().__init__()` and the cached submodules imply a `torch.nn.Module` subclass.
class CachedTargetsTriggerLM(torch.nn.Module):
    def __init__(self, model: GPT2LMHeadModel, targets: torch.Tensor):
        super().__init__()
        self._transformer = model.transformer
        self.input_embeddings = model.get_input_embeddings()
        self.lm_head = model.lm_head

        # At this point we always use the same targets, so we might as well cache all of these
        self.targets: Final[torch.Tensor] = targets
        self.flat_targets: Final[torch.Tensor] = targets.reshape(-1)
        self.targets_padding_mask: Final[torch.Tensor] = targets.eq(-1)
        # `targets` is padded with `-1`s at the end of the sequence dimension. This is fine when
        # using them as labels, since `-1`s in labels are ignored in the loss. However, `-1` is
        # not a valid id in inputs, so we put `1` in its place, which yields useless embeddings
        # at those positions. That is not an issue since they never enter the loss, capisce?
        target_indices_for_embeddings = self.targets.clone()
        target_indices_for_embeddings[self.targets_padding_mask] = 1
        self.targets_embeddings: Final[torch.Tensor] = self.input_embeddings(
            target_indices_for_embeddings
        )

        self.num_targets = targets.shape[0]
        self.max_target_length = targets.shape[1]