def longformer_modifier(final_dictionary, tokenizer, attention_window):
    """
    Creates the `global_attention_mask` for the longformer. Tokens with global attention
    attend to all other tokens, and all other tokens attend to them. This is important for
    task-specific finetuning because it makes the model more flexible at representing the
    task. For example, for classification, the `<s>` token should be given global attention.
    For QA, all question tokens should also have global attention. For summarization,
    global attention is given to all of the `<s>` (RoBERTa 'CLS' equivalent) tokens. Please
    refer to the `Longformer paper <https://arxiv.org/abs/2004.05150>`_ for more details.
    Mask values selected in ``[0, 1]``: ``0`` for local attention, ``1`` for global
    attention.
    """
    # `batch_size` is the number of attention masks (one mask per input sequence)
    batch_size = len(final_dictionary["source_mask"])
    # `sequence_length` is the number of tokens for the first sequence in the batch
    sequence_length = len(final_dictionary["source_mask"][0])
    # create `global_attention_mask` using the above details
    global_attention_mask = torch.tensor([[0] * sequence_length] * batch_size)
    # set the `sent_rep_token_ids` to 1, which is global attention
    for idx, input_sequence in enumerate(final_dictionary["source"]):
        for inner_idx, token_id in enumerate(input_sequence):
            if token_id == tokenizer.cls_token_id:
                global_attention_mask[idx, inner_idx] = 1

    final_dictionary["global_attention_mask"] = global_attention_mask

    for key, item in final_dictionary.items():
        final_dictionary[key] = pad_tensors(
            item,
            nearest_multiple_of=attention_window[0],
        )

    return final_dictionary
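
# NOTE: `pad_tensors` is used above but not defined in this excerpt. The helper
# below is a minimal, hypothetical sketch, assuming it right-pads the last
# dimension of a tensor with `pad_id` so the sequence length is a multiple of the
# Longformer attention window; the project's actual implementation may differ.
import math

import torch.nn.functional as F


def pad_tensors(tensor, pad_id=0, nearest_multiple_of=512):
    """Right-pad ``tensor`` along its last dimension with ``pad_id`` so that its
    length becomes a multiple of ``nearest_multiple_of``."""
    sequence_length = tensor.shape[-1]
    # round the current length up to the nearest multiple of the window size
    padded_length = math.ceil(sequence_length / nearest_multiple_of) * nearest_multiple_of
    if padded_length == sequence_length:
        return tensor
    return F.pad(tensor, (0, padded_length - sequence_length), value=pad_id)


# Example (hypothetical): a batch of shape (2, 700) padded to a 512-token window
# would become shape (2, 1024).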
def predict(self, input_sequence):
    """Summarizes ``input_sequence`` using the model. Can summarize a list of
    sequences at once.

    Args:
        input_sequence (str or list[str]): The text to be summarized.

    Returns:
        str or list[str]: The summary text.
    """
    # If a single string is passed, wrap it in a list so `batch_encode_plus()`
    # processes it correctly
    if type(input_sequence) is str:
        input_sequence = [input_sequence]

    input_sequence_encoded = self.tokenizer.batch_encode_plus(
        input_sequence,
        pad_to_max_length=False,
        truncation=True,
        return_attention_mask=False,
        return_token_type_ids=False,
    )["input_ids"]
    input_sequence_encoded = torch.tensor(input_sequence_encoded)

    # If using the LongformerEncoderDecoder then apply the padding for sliding
    # chunks attention.
    if any(
        x in self.hparams.model_name_or_path.lower()
        for x in ["led-large", "led-base"]
    ):
        input_sequence_encoded = pad_tensors(
            input_sequence_encoded,
            nearest_multiple_of=self.model.config.attention_window[0],
        )

    # Generate the summary token ids with beam search, timing the generation step.
    t0 = time()
    generated_ids = self.model.generate(
        input_ids=input_sequence_encoded,
        num_beams=3,
        decoder_start_token_id=self.target_boseq_token_id,
        bos_token_id=self.target_boseq_token_id,
        eos_token_id=self.target_eoseq_token_id,
        pad_token_id=self.target_eoseq_token_id,
        max_length=(
            self.hparams.gen_max_len
            if self.hparams.gen_max_len
            else int(self.tokenizer.model_max_length / 2)
        ),
        no_repeat_ngram_size=3,
        use_cache=True,
    )
    generation_time = time() - t0
    logger.debug("Generation Time: %.2f", generation_time)

    # Convert the generated token ids back into cleaned summary text.
    generated_ids = generated_ids.tolist()
    prediction = self.ids_to_clean_text(generated_ids)

    return prediction
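
# NOTE: `ids_to_clean_text` is called in `predict()` but not shown in this
# excerpt. The method below is a minimal, hypothetical sketch that assumes a
# Hugging Face tokenizer; the project's actual implementation may differ.
def ids_to_clean_text(self, generated_ids):
    """Decode ``generated_ids`` into summary strings, dropping special tokens
    and extra whitespace."""
    texts = self.tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return [text.strip() for text in texts]


# Example (hypothetical) usage on a trained model instance named `summarizer`:
# summary = summarizer.predict("Some long document text to condense ...")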