from typing import Final, Optional

import torch
from transformers import GPT2LMHeadModel


def get_averaged_grad(
    model: GPT2LMHeadModel,
    trigger_tokens: torch.Tensor,
    target_tokens: torch.Tensor,
    targets_embeddings: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the gradient of the loss with respect to the trigger embeddings,
    averaged across all the targets.
    """
    num_targets = target_tokens.shape[0]
    trigger_length = trigger_tokens.shape[0]
    targets_padding_mask = target_tokens.eq(-1)
    model_input_embeddings = model.get_input_embeddings()
    # Detach so the trigger embeddings are a leaf tensor that accumulates a gradient
    trigger_embeddings = (
        model_input_embeddings(trigger_tokens).detach().requires_grad_(True)
    )
    if targets_embeddings is None:
        # `-1` is not a valid token id, so replace padding positions with a
        # placeholder id before embedding. The resulting embeddings are masked
        # out of both the attention and the loss, so their value doesn't matter.
        target_inputs = target_tokens.clone()
        target_inputs[targets_padding_mask] = 1
        targets_embeddings = model_input_embeddings(target_inputs)
    # Broadcast the single trigger over every target:
    # (num_targets, trigger_length + max_target_length, hidden)
    lm_input = torch.cat(
        [
            trigger_embeddings.unsqueeze(0).expand(
                num_targets, *trigger_embeddings.shape
            ),
            targets_embeddings,
        ],
        dim=1,
    )
    model.zero_grad()
    # Attend to the whole trigger and to the non-padding part of each target
    attention_mask = torch.cat(
        [
            torch.ones(
                (num_targets, trigger_length),
                device=target_tokens.device,
                dtype=torch.bool,
            ),
            targets_padding_mask.logical_not(),
        ],
        dim=1,
    )
    lm_output = model(inputs_embeds=lm_input, attention_mask=attention_mask)
    logits = lm_output[0]
    # The logits at position i predict token i + 1, so the predictions for the
    # target tokens start at the last trigger position and stop one before the end
    target_logits = logits[:, trigger_length - 1 : -1, :].reshape(
        num_targets * target_tokens.shape[1], -1
    )
    # The targets are padded with `-1`, so that must be the ignored index here
    # (with the default `-100`, the `-1` labels would be out of bounds)
    loss = torch.nn.functional.cross_entropy(
        target_logits, target_tokens.view(-1), ignore_index=-1
    )
    loss.backward()
    embeddings_average_grad = trigger_embeddings.grad.detach()
    model.zero_grad()
    return embeddings_average_grad
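# A minimal sketch (not part of the original code) of what this averaged gradient
# is typically used for: a HotFlip-style candidate selection, which scores every
# vocabulary token by the first-order approximation of the loss change
# (e_new - e_old)^T . grad and keeps, per trigger position, the tokens that most
# decrease it. The name `get_hotflip_candidates` and the `num_candidates` default
# are illustrative assumptions, not the original API.
def get_hotflip_candidates(
    model: GPT2LMHeadModel,
    averaged_grad: torch.Tensor,
    num_candidates: int = 10,
) -> torch.Tensor:
    """Return, for each trigger position, the ids of the `num_candidates` tokens
    whose embeddings most decrease the first-order loss approximation.
    """
    with torch.no_grad():
        embedding_matrix = model.get_input_embeddings().weight  # (vocab, hidden)
        # (trigger_length, vocab): dot product of each position's gradient with
        # every token embedding; the per-position e_old term is constant, so it
        # doesn't affect the ranking
        scores = averaged_grad @ embedding_matrix.T
        # The loss decreases most where the score is most negative
        return scores.topk(num_candidates, dim=-1, largest=False).indices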
def __init__(self, model: GPT2LMHeadModel, targets: torch.Tensor):
    super().__init__()
    self._transformer = model.transformer
    self.input_embeddings = model.get_input_embeddings()
    self.lm_head = model.lm_head
    # At this point we always use the same targets, so we might as well cache all of these
    self.targets: Final[torch.Tensor] = targets
    self.flat_targets: Final[torch.Tensor] = targets.reshape(-1)
    self.targets_padding_mask: Final[torch.Tensor] = targets.eq(-1)
    # `targets` is padded with `-1`s at the end of the sequence dimension. This is
    # convenient for labels, since `-1` labels are ignored in the loss. However,
    # `-1` is not a valid id in inputs, so we put 1 in their places. The resulting
    # embeddings are useless, but that is not an issue since they never reach the
    # loss, capisce?
    target_indices_for_embeddings = self.targets.clone()
    target_indices_for_embeddings[self.targets_padding_mask] = 1
    self.targets_embeddings: Final[torch.Tensor] = self.input_embeddings(
        target_indices_for_embeddings
    )
    self.num_targets = targets.shape[0]
    self.max_target_length = targets.shape[1]
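# A minimal usage sketch (assumed glue code, not the original): the constructor
# above caches `targets_embeddings` precisely so that a search loop can hand them
# straight to `get_averaged_grad` instead of re-embedding the targets at every
# step. The embeddings are detached here so that repeated backward passes do not
# try to walk the (already freed) graph of the one-off embedding forward; the
# gradient for the trigger still flows through the transformer itself.
def trigger_search_step(
    model: GPT2LMHeadModel,
    trigger_tokens: torch.Tensor,
    targets: torch.Tensor,
    targets_embeddings: torch.Tensor,
) -> torch.Tensor:
    """One gradient + candidate-selection step; evaluating and applying the
    candidate swaps is elided.
    """
    grad = get_averaged_grad(
        model,
        trigger_tokens,
        targets,
        targets_embeddings=targets_embeddings.detach(),
    )
    # `get_hotflip_candidates` is the illustrative helper sketched earlier
    return get_hotflip_candidates(model, grad)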