Example #1
def longformer_modifier(final_dictionary, tokenizer, attention_window):
    """
    Creates the `global_attention_mask` for the longformer. Tokens with global attention
    attend to all other tokens, and all other tokens attend to them. This is important for
    task-specific finetuning because it makes the model more flexible at representing the
    task. For example, for classification, the `<s>` token should be given global attention.
    For QA, all question tokens should also have global attention. For summarization,
    global attention is given to all of the `<s>` (RoBERTa 'CLS' equivalent) tokens. Please
    refer to the `Longformer paper <https://arxiv.org/abs/2004.05150>`_ for more details. Mask
    values selected in ``[0, 1]``: ``0`` for local attention, ``1`` for global attention.
    """
    # `batch_size` is the number of attention masks (one mask per input sequence)
    batch_size = len(final_dictionary["source_mask"])
    # `sequence_length` is the number of tokens for the first sequence in the batch
    sequence_length = len(final_dictionary["source_mask"][0])
    # create `global_attention_mask` using the above details
    global_attention_mask = torch.tensor([[0] * sequence_length] * batch_size)
    # set the `sent_rep_token_ids` to 1, which is global attention
    for idx, input_sequence in enumerate(final_dictionary["source"]):
        for inner_idx, token_id in enumerate(input_sequence):
            if token_id == tokenizer.cls_token_id:
                global_attention_mask[idx, inner_idx] = 1

    final_dictionary["global_attention_mask"] = global_attention_mask

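    # Pad every tensor in the batch to the nearest multiple of the attention
    # window, as required by Longformer's sliding-window ("sliding chunks")
    # attention.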
    for key, item in final_dictionary.items():
        final_dictionary[key] = pad_tensors(
            item,
            nearest_multiple_of=attention_window[0],
        )

    return final_dictionary
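
For illustration, the minimal sketch below builds the same kind of `global_attention_mask` for a toy batch using the Hugging Face `roberta-base` tokenizer; the tokenizer name and the toy sentences are assumptions, and the project's `pad_tensors` step is omitted.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# Tokenize a toy batch; `input_ids` plays the role of `final_dictionary["source"]`.
batch = tokenizer(
    ["First toy document.", "A second, slightly longer toy document."],
    padding=True,
    return_tensors="pt",
)
source = batch["input_ids"]

# Local attention (0) everywhere, then global attention (1) on every <s>/CLS
# token, mirroring the nested loop in `longformer_modifier`.
global_attention_mask = torch.zeros_like(source)
global_attention_mask[source == tokenizer.cls_token_id] = 1
print(global_attention_mask)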
Example #2
    def predict(self, input_sequence):
        """Summaries ``input_sequence`` using the model. Can summarize a list of
        sequences at once.

        Args:
            input_sequence (str or list[str]): The text to be summarized.

        Returns:
            str or list[str]: The summary text.
        """
        # If a single string is passed, wrap it in a list so `batch_encode_plus()`
        # processes it correctly
        if isinstance(input_sequence, str):
            input_sequence = [input_sequence]

        input_sequence_encoded = self.tokenizer.batch_encode_plus(
            input_sequence,
            pad_to_max_length=False,
            truncation=True,
            return_attention_mask=False,
            return_token_type_ids=False,
        )["input_ids"]
        input_sequence_encoded = torch.tensor(input_sequence_encoded)

        # If using the LongformerEncoderDecoder then apply the padding for sliding
        # chunks attention.
        if any(x in self.hparams.model_name_or_path.lower()
               for x in ["led-large", "led-base"]):
            input_sequence_encoded = pad_tensors(
                input_sequence_encoded,
                nearest_multiple_of=self.model.config.attention_window[0],
            )

        t0 = time()
        generated_ids = self.model.generate(
            input_ids=input_sequence_encoded,
            num_beams=3,
            decoder_start_token_id=self.target_boseq_token_id,
            bos_token_id=self.target_boseq_token_id,
            eos_token_id=self.target_eoseq_token_id,
            pad_token_id=self.target_eoseq_token_id,
            max_length=(self.hparams.gen_max_len if self.hparams.gen_max_len
                        else int(self.tokenizer.model_max_length / 2)),
            no_repeat_ngram_size=3,
            use_cache=True,
        )
        generation_time = time() - t0
        logger.debug("Generation Time: %.2f", generation_time)

        generated_ids = generated_ids.tolist()
        prediction = self.ids_to_clean_text(generated_ids)

        return prediction
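
The same encode, generate, decode pattern can be sketched with a stock Hugging Face seq2seq checkpoint; the model name below (`sshleifer/distilbart-cnn-12-6`) is only an assumed example, and the generation arguments mirror those used in `predict`.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "sshleifer/distilbart-cnn-12-6"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

inputs = tokenizer(
    ["A long news article to be summarized ..."],
    truncation=True,
    return_tensors="pt",
)

# Beam search with the same decoding constraints as `predict` above.
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    num_beams=3,
    max_length=int(tokenizer.model_max_length / 2),
    no_repeat_ngram_size=3,
    use_cache=True,
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))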