def __init__(
    self,
    vocab: Vocabulary,
    model_name: str,
    beam_search: Lazy[BeamSearch] = Lazy(BeamSearch, beam_size=3, max_steps=50),
    checkpoint_wrapper: Optional[CheckpointWrapper] = None,
    weights_path: Optional[Union[str, PathLike]] = None,
    **kwargs,
) -> None:
    super().__init__(vocab, **kwargs)
    self._model_name = model_name
    # We only instantiate this when we need it.
    self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
    self.t5 = T5Module.from_pretrained_module(
        model_name,
        beam_search=beam_search,
        ddp_accelerator=self.ddp_accelerator,
        checkpoint_wrapper=checkpoint_wrapper,
        weights_path=weights_path,
    )

    exclude_indices = {
        self.t5.pad_token_id,
        self.t5.decoder_start_token_id,
        self.t5.eos_token_id,
    }
    self._metrics = [
        ROUGE(exclude_indices=exclude_indices),
        BLEU(exclude_indices=exclude_indices),
    ]
def global_distributed_rouge(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: ROUGE,
    metric_kwargs: Dict[str, Any],
    desired_values: Dict[str, Any],
):
    kwargs = {}

    # Use the arguments meant for the process with rank `global_rank`.
    for argname in metric_kwargs:
        kwargs[argname] = metric_kwargs[argname][global_rank]

    metric(**kwargs)

    metrics = metric.get_metric()

    # Unigram
    unigram_recall = metric._total_rouge_n_recalls[1]
    assert_allclose(unigram_recall, desired_values["unigram_recall"])
    unigram_precision = metric._total_rouge_n_precisions[1]
    assert_allclose(unigram_precision, desired_values["unigram_precision"])
    unigram_f1 = metric._total_rouge_n_f1s[1]
    assert_allclose(unigram_f1, desired_values["unigram_f1"])

    assert metrics["ROUGE-1_R"] == unigram_recall / metric._total_sequence_count
    assert metrics["ROUGE-1_P"] == unigram_precision / metric._total_sequence_count
    assert metrics["ROUGE-1_F1"] == unigram_f1 / metric._total_sequence_count

    # Bigram
    bigram_recall = metric._total_rouge_n_recalls[2]
    assert_allclose(bigram_recall, desired_values["bigram_recall"])
    bigram_precision = metric._total_rouge_n_precisions[2]
    assert_allclose(bigram_precision, desired_values["bigram_precision"])
    bigram_f1 = metric._total_rouge_n_f1s[2]
    assert_allclose(bigram_f1, desired_values["bigram_f1"])

    assert metrics["ROUGE-2_R"] == bigram_recall / metric._total_sequence_count
    assert metrics["ROUGE-2_P"] == bigram_precision / metric._total_sequence_count
    assert metrics["ROUGE-2_F1"] == bigram_f1 / metric._total_sequence_count

    # ROUGE-L
    assert_allclose(metric._total_rouge_l_f1, desired_values["total_rouge_l_f1"])
    assert metrics["ROUGE-L"] == metric._total_rouge_l_f1 / metric._total_sequence_count
def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
    super().__init__(vocab, **kwargs)
    self._model_name = model_name
    # We only instantiate this when we need it.
    self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
    self.t5 = T5Module.from_pretrained_module(model_name)

    exclude_indices = {
        self.t5.pad_token_id,
        self.t5.decoder_start_token_id,
        self.t5.eos_token_id,
    }
    self._metrics = [
        ROUGE(exclude_indices=exclude_indices),
        BLEU(exclude_indices=exclude_indices),
    ]
def test_distributed_rouge(self):
    predictions = [torch.tensor([[1, 0, 1, 2], [1, 0, 3, 0]]), torch.tensor([[1, 2, 3, 0]])]
    targets = [torch.tensor([[2, 0, 1, 2], [1, 2, 1, 0]]), torch.tensor([[1, 0, 2, 3]])]

    metric_kwargs = {"predictions": predictions, "gold_targets": targets}
    desired_values = {}
    desired_values["unigram_recall"] = 2 / 3 + 1 / 3 + 3 / 3
    desired_values["unigram_precision"] = 2 / 3 + 1 / 2 + 3 / 3
    desired_values["unigram_f1"] = self.f1(2 / 3, 2 / 3) + self.f1(1 / 2, 1 / 3) + self.f1(3 / 3, 3 / 3)

    desired_values["bigram_recall"] = 1 / 1 + 0 / 2 + 1 / 1
    desired_values["bigram_precision"] = 1 / 1 + 0 + 1 / 2
    desired_values["bigram_f1"] = self.f1(1 / 1, 1 / 1) + self.f1(0, 0 / 2) + self.f1(1 / 2, 1 / 1)

    desired_values["total_rouge_l_f1"] = (
        self.f1(2 / 3, 2 / 3) + self.f1(1 / 3, 1 / 2) + self.f1(3 / 3, 3 / 3)
    )

    run_distributed_test(
        [-1, -1],
        global_distributed_rouge,
        ROUGE(exclude_indices={0}),
        metric_kwargs,
        desired_values,
    )
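The per-pair terms in `desired_values` above can be reproduced by hand. The following standalone sketch is illustrative only (it is not the library's ROUGE implementation); it counts clipped n-gram overlap while dropping any n-gram that contains an excluded index, which is consistent with the expected numbers for `exclude_indices={0}`:

from collections import Counter

def clipped_ngram_scores(pred, gold, n, exclude=frozenset({0})):
    # Build n-gram counters, dropping any n-gram that contains an excluded index (e.g. padding).
    def ngrams(seq):
        grams = (tuple(seq[i:i + n]) for i in range(len(seq) - n + 1))
        return Counter(g for g in grams if not (set(g) & exclude))

    pred_counts, gold_counts = ngrams(pred), ngrams(gold)
    # Clipped overlap: each prediction n-gram counts at most as often as it appears in the reference.
    overlap = sum(min(count, gold_counts[gram]) for gram, count in pred_counts.items())
    recall = overlap / sum(gold_counts.values()) if gold_counts else 0.0
    precision = overlap / sum(pred_counts.values()) if pred_counts else 0.0
    return recall, precision

# First prediction/target pair from the test above.
print(clipped_ngram_scores([1, 0, 1, 2], [2, 0, 1, 2], n=1))  # (2/3, 2/3) -> unigram terms
print(clipped_ngram_scores([1, 0, 1, 2], [2, 0, 1, 2], n=2))  # (1/1, 1/1) -> bigram terms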
class Bart(Model):
    """
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural
    Language Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461).
    The Bart model here uses a language modeling head and thus can be used for text generation.
    """

    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
            `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
        max_decoding_steps : `int`, optional (default = `140`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. The default is from the BART paper.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to use in BART. By default, the original BART encoder is used.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(
            self._end_id, max_steps=max_decoding_steps, beam_size=beam_size or 1
        )

        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace the BART encoder with the given encoder. We need to extract the two embedding
        # layers so that we can use them in the encoder wrapper.
        if encoder is not None:
            assert (
                encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
            )
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )

    @overrides
    def forward(
        self, source_tokens: TextFieldTensors, target_tokens: TextFieldTensors = None
    ) -> Dict[str, torch.Tensor]:
        """
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key.
            If no target tokens are given, the source tokens are shifted to the right by 1.

        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape
            `(batch_size, max_target_length, target_vocab_size)` and the `loss`. During inference,
            it contains `predictions` of shape `(batch_size, max_decoding_steps)` and
            `log_probabilities` of shape `(batch_size,)`.
        """
        inputs = source_tokens
        targets = target_tokens
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs["tokens"]["mask"]

        outputs = {}

        # If no targets are provided, then shift input to right by 1. Bart already does this
        # internally but it does not use them for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets["tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:
            decoder_logits = self.bart(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
            )[0]

            outputs["decoder_logits"] = decoder_logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks.
            outputs["loss"] = sequence_cross_entropy_with_logits(
                decoder_logits,
                target_ids[:, 1:].contiguous(),
                target_mask[:, 1:].contiguous(),
                label_smoothing=0.1,
                average="token",
            )
        else:
            # Use decoder start id and start of sentence to start decoder.
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id, self._start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "encoder_states": None,
            }
            beam_result = self._beam_search.search(
                initial_decoder_ids, initial_state, self.take_step
            )

            predictions = beam_result[0]
            max_pred_indices = (
                beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1])
            )
            predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (
                beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1)
            )

            self.make_output_human_readable(outputs)

        return outputs

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache):
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            for attention_name, attention_cache in layer_cache.items():
                for tensor_name, cache_value in attention_cache.items():
                    key = (layer_index, attention_name, tensor_name)
                    cache_dict[key] = cache_value
        return cache_dict

    @staticmethod
    def _dict_to_decoder_cache(cache_dict):
        decoder_cache = []
        for key, cache_value in cache_dict.items():
            # Split key and extract index and dict keys.
            layer_idx, attention_name, tensor_name = key
            # Extend decoder_cache to fit layer_idx + 1 layers.
            decoder_cache = decoder_cache + [{} for _ in range(layer_idx + 1 - len(decoder_cache))]
            cache = decoder_cache[layer_idx]
            if attention_name not in cache:
                cache[attention_name] = {}
            assert tensor_name not in cache[attention_name]
            cache[attention_name][tensor_name] = cache_value
        return decoder_cache

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)`
            and an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Only the last predictions are needed for the decoder, but we need to pad the decoder ids
        # to not mess up the positional embeddings in the decoder.
        padding_size = 0
        if step > 0:
            padding_size = step + 1
            padding = torch.full(
                (last_predictions.shape[0], padding_size),
                self._pad_id,
                dtype=last_predictions.dtype,
                device=last_predictions.device,
            )
            last_predictions = torch.cat([padding, last_predictions], dim=-1)

        decoder_cache = None
        decoder_cache_dict = {
            k: (state[k].contiguous() if state[k] is not None else None)
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        log_probabilities = None
        for i in range(padding_size, last_predictions.shape[1]):
            encoder_outputs = (
                (state["encoder_states"],) if state["encoder_states"] is not None else None
            )
            outputs = self.bart(
                input_ids=state["input_ids"],
                attention_mask=state["input_mask"],
                encoder_outputs=encoder_outputs,
                decoder_input_ids=last_predictions[:, : i + 1],
                past_key_values=decoder_cache,
                use_cache=True,
            )

            decoder_log_probabilities = F.log_softmax(outputs[0][:, 0], dim=-1)

            if log_probabilities is None:
                log_probabilities = decoder_log_probabilities
            else:
                idx = last_predictions[:, i].view(-1, 1)
                log_probabilities = decoder_log_probabilities + log_probabilities.gather(
                    dim=-1, index=idx
                )

            decoder_cache = outputs[1]

            state["encoder_states"] = outputs[2]

        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        return log_probabilities, state

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor
            should have shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of
            lists of tokens.
        """
        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()}, self.vocab
            )
        output_dict["predicted_tokens"] = predicted_tokens

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics
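A minimal inference sketch for the class above, assuming the AllenNLP imports used by `Bart` are in scope, the "facebook/bart-base" weights can be downloaded, and the token ids below are merely illustrative stand-ins for what a PretrainedTransformerIndexer would produce:

import torch
from allennlp.data import Vocabulary

# Hypothetical smoke test: exercise the beam-search branch of forward() on one toy example.
model = Bart(model_name="facebook/bart-base", vocab=Vocabulary())
model.eval()  # forward() only runs beam search when the model is not in training mode

source_tokens = {
    "tokens": {
        "token_ids": torch.tensor([[0, 31414, 232, 2]]),  # illustrative <s> ... </s> ids
        "mask": torch.tensor([[True, True, True, True]]),
    }
}
with torch.no_grad():
    outputs = model(source_tokens)  # no targets -> returns "predictions" and "log_probabilities"
print(outputs["predicted_tokens"][0])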
class Seq2seqPlmsGenerator(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        pretrained_model_path,
        beam_size=5,
        max_decoding_steps=140,
        indexer=None,
    ):
        super().__init__(vocab)
        self.plm = MT5ForConditionalGeneration.from_pretrained(pretrained_model_path)
        self._indexer = indexer or PretrainedTransformerIndexer(
            pretrained_model_path, namespace="tokens"
        )

        self._start_id = self.plm.config.decoder_start_token_id
        self._decoder_start_id = self.plm.config.decoder_start_token_id
        self._end_id = self.plm.config.eos_token_id
        self._pad_id = self.plm.config.pad_token_id

        self._beam_search = BeamSearch(
            self._end_id, max_steps=max_decoding_steps, beam_size=beam_size or 1
        )

        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

    @overrides
    def forward(self, source_tokens, target_tokens=None) -> Dict[str, torch.Tensor]:
        inputs = source_tokens
        targets = target_tokens
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs["tokens"]["mask"]

        outputs = {}

        # If no targets are provided, then shift input to right by 1. The underlying model already
        # does this internally, but it does not use them for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets["tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:  # training
            plm_outputs = self.plm(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
                return_dict=True,
            )
            outputs["decoder_logits"] = plm_outputs.logits
            outputs["loss"] = sequence_cross_entropy_with_logits(
                plm_outputs.logits,
                cast(torch.LongTensor, target_ids[:, 1:].contiguous()),
                cast(torch.BoolTensor, target_mask[:, 1:].contiguous()),
                label_smoothing=0.1,
                average="token",
            )
        elif targets is not None:  # validation
            plm_outputs = self.plm(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
                return_dict=True,
            )
            outputs["decoder_logits"] = plm_outputs.logits
            outputs["loss"] = sequence_cross_entropy_with_logits(
                plm_outputs.logits,
                cast(torch.LongTensor, target_ids[:, 1:].contiguous()),
                cast(torch.BoolTensor, target_mask[:, 1:].contiguous()),
                label_smoothing=0.1,
            )
            self._rouge(torch.argmax(plm_outputs.logits, -1), target_ids)
            self._bleu(torch.argmax(plm_outputs.logits, -1), target_ids)
        else:  # prediction
            # Use decoder start id to start the decoder.
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
            }
            beam_result = self._beam_search.search(
                initial_decoder_ids, initial_state, self.take_step
            )

            predictions = beam_result[0]
            logger.info(beam_result)
            max_pred_indices = (
                beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1])
            )
            predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (
                beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1)
            )

            self.make_output_human_readable(outputs)

        return outputs

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache: DecoderCacheType) -> Dict[str, torch.Tensor]:
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            # Each layer caches the key and value tensors for its self-attention and cross-attention.
            # Hence the `layer_cache` tuple has 4 elements.
            assert len(layer_cache) == 4
            for tensor_index, tensor in enumerate(layer_cache):
                key = f"decoder_cache_{layer_index}_{tensor_index}"
                cache_dict[key] = tensor
        return cache_dict

    def _dict_to_decoder_cache(self, cache_dict: Dict[str, torch.Tensor]) -> DecoderCacheType:
        decoder_cache = []
        for layer_index in range(self.plm.config.num_layers):
            base_key = f"decoder_cache_{layer_index}_"
            layer_cache = (
                cache_dict[base_key + "0"],
                cache_dict[base_key + "1"],
                cache_dict[base_key + "2"],
                cache_dict[base_key + "3"],
            )
            decoder_cache.append(layer_cache)
        assert decoder_cache
        return tuple(decoder_cache)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)`
            and an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        decoder_cache = None
        decoder_cache_dict = {
            k: state[k].contiguous()
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        encoder_outputs = (state["encoder_states"],) if "encoder_states" in state else None
        outputs = self.plm(
            input_ids=state["input_ids"] if encoder_outputs is None else None,
            attention_mask=state["input_mask"],
            encoder_outputs=encoder_outputs,
            decoder_input_ids=last_predictions,
            past_key_values=decoder_cache,
            use_cache=True,
            return_dict=True,
        )

        logits = outputs.logits[:, -1, :]
        log_probabilities = F.log_softmax(logits, dim=-1)

        decoder_cache = outputs.past_key_values
        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        state["encoder_states"] = outputs.encoder_last_hidden_state

        return log_probabilities, state

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor
            should have shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of
            lists of tokens.
        """
        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()},
                self.vocab,
            )
        output_dict["predicted_tokens"] = predicted_tokens  # type: ignore
        output_dict["predicted_text"] = self._indexer._tokenizer.batch_decode(
            predictions.tolist(), skip_special_tokens=True
        )

        return output_dict
class RougeTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.metric = ROUGE(exclude_indices={0})

    def f1(self, r, p):
        if r == p == 0:
            return 0
        return 2 * r * p / (r + p)

    @multi_device
    def test_rouge(self, device: str):
        self.metric.reset()

        predictions = torch.tensor([[1, 0, 1, 2], [1, 0, 3, 0], [1, 2, 3, 0]], device=device)
        targets = torch.tensor([[2, 0, 1, 2], [1, 2, 1, 0], [1, 0, 2, 3]], device=device)

        self.metric(predictions, targets)
        metrics = self.metric.get_metric()

        assert self.metric._total_sequence_count == 3

        # ROUGE-N

        # Unigram
        unigram_recall = self.metric._total_rouge_n_recalls[1]
        assert unigram_recall == 2 / 3 + 1 / 3 + 3 / 3
        unigram_precision = self.metric._total_rouge_n_precisions[1]
        assert unigram_precision == 2 / 3 + 1 / 2 + 3 / 3
        unigram_f1 = self.metric._total_rouge_n_f1s[1]
        assert unigram_f1 == self.f1(2 / 3, 2 / 3) + self.f1(1 / 2, 1 / 3) + self.f1(3 / 3, 3 / 3)

        assert metrics["ROUGE-1_R"] == unigram_recall / self.metric._total_sequence_count
        assert metrics["ROUGE-1_P"] == unigram_precision / self.metric._total_sequence_count
        assert metrics["ROUGE-1_F1"] == unigram_f1 / self.metric._total_sequence_count

        # Bigram
        bigram_recall = self.metric._total_rouge_n_recalls[2]
        assert bigram_recall == 1 / 1 + 0 / 2 + 1 / 1
        bigram_precision = self.metric._total_rouge_n_precisions[2]
        assert bigram_precision == 1 / 1 + 0 + 1 / 2
        bigram_f1 = self.metric._total_rouge_n_f1s[2]
        assert bigram_f1 == self.f1(1 / 1, 1 / 1) + self.f1(0, 0 / 2) + self.f1(1 / 2, 1 / 1)

        assert metrics["ROUGE-2_R"] == bigram_recall / self.metric._total_sequence_count
        assert metrics["ROUGE-2_P"] == bigram_precision / self.metric._total_sequence_count
        assert metrics["ROUGE-2_F1"] == bigram_f1 / self.metric._total_sequence_count

        # ROUGE-L
        assert self.metric._total_rouge_l_f1 == (
            self.f1(2 / 3, 2 / 3) + self.f1(1 / 3, 1 / 2) + self.f1(3 / 3, 3 / 3)
        )

        assert (
            metrics["ROUGE-L"] == self.metric._total_rouge_l_f1 / self.metric._total_sequence_count
        )

    def test_rouge_with_zero_counts(self):
        self.metric.reset()
        metrics = self.metric.get_metric()
        for score in metrics.values():
            assert score == 0.0
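The ROUGE-L expectations in `test_rouge` follow from longest-common-subsequence overlap computed after the excluded index is dropped from each sequence. This standalone sketch (illustrative only; the library's implementation may differ in detail) reproduces the first per-pair term, `f1(2/3, 2/3)`, assuming non-empty sequences after filtering:

def lcs_length(a, b):
    # Classic dynamic-programming longest-common-subsequence length.
    table = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, start=1):
        for j, y in enumerate(b, start=1):
            table[i][j] = table[i - 1][j - 1] + 1 if x == y else max(table[i - 1][j], table[i][j - 1])
    return table[-1][-1]

def rouge_l_f1(pred, gold, exclude=frozenset({0})):
    pred = [t for t in pred if t not in exclude]
    gold = [t for t in gold if t not in exclude]
    lcs = lcs_length(pred, gold)
    recall, precision = lcs / len(gold), lcs / len(pred)
    return 0.0 if recall == precision == 0 else 2 * recall * precision / (recall + precision)

# First prediction/target pair from test_rouge: recall = precision = 2/3.
print(rouge_l_f1([1, 0, 1, 2], [2, 0, 1, 2]))  # == f1(2/3, 2/3) = 2/3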
class Bart(Model):
    """
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural
    Language Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461).
    The Bart model here uses a language modeling head and thus can be used for text generation.

    # Parameters

    model_name : `str`, required
        Name of the pre-trained BART model to use. Available options can be found in
        `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
    vocab : `Vocabulary`, required
        Vocabulary containing source and target vocabularies.
    beam_search : `Lazy[BeamSearch]`, optional (default = `Lazy(BeamSearch)`)
        This is used during inference to select the tokens of the decoded output sequence.
    indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
        Indexer to be used for converting decoded sequences of ids to sequences of tokens.
    encoder : `Seq2SeqEncoder`, optional (default = `None`)
        Encoder to use in BART. By default, the original BART encoder is used.
    """

    default_predictor = "seq2seq"

    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
        indexer: PretrainedTransformerIndexer = None,
        encoder: Seq2SeqEncoder = None,
        **kwargs,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
        # kwargs. If so, update the BeamSearch object before constructing and raise a
        # DeprecationWarning.
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead."
        )
        beam_search_extras = {}
        if "beam_size" in kwargs:
            beam_search_extras["beam_size"] = kwargs["beam_size"]
            warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
        if "max_decoding_steps" in kwargs:
            beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
            warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
        self._beam_search = beam_search.construct(
            end_index=self._end_id, vocab=self.vocab, **beam_search_extras
        )

        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace the BART encoder with the given encoder. We need to extract the two embedding
        # layers so that we can use them in the encoder wrapper.
        if encoder is not None:
            assert (
                encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
            )
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )

    def forward(
        self, source_tokens: TextFieldTensors, target_tokens: TextFieldTensors = None
    ) -> Dict[str, torch.Tensor]:
        """
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key.
            If no target tokens are given, the source tokens are shifted to the right by 1.

        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape
            `(batch_size, max_target_length, target_vocab_size)` and the `loss`. During inference,
            it contains `predictions` of shape `(batch_size, max_decoding_steps)` and
            `log_probabilities` of shape `(batch_size,)`.
        """
        inputs = source_tokens
        targets = target_tokens
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs["tokens"]["mask"]

        outputs = {}

        # If no targets are provided, then shift input to right by 1. Bart already does this
        # internally but it does not use them for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets["tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:
            bart_outputs = self.bart(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
                return_dict=True,
            )
            outputs["decoder_logits"] = bart_outputs.logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks.
            outputs["loss"] = sequence_cross_entropy_with_logits(
                bart_outputs.logits,
                cast(torch.LongTensor, target_ids[:, 1:].contiguous()),
                cast(torch.BoolTensor, target_mask[:, 1:].contiguous()),
                label_smoothing=0.1,
                average="token",
            )
        else:
            # Use decoder start id and start of sentence to start decoder.
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
            }
            beam_result = self._beam_search.search(
                initial_decoder_ids, initial_state, self.take_step
            )

            predictions = beam_result[0]
            max_pred_indices = (
                beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1])
            )
            predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (
                beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1)
            )

            self.make_output_human_readable(outputs)

        return outputs

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache: DecoderCacheType) -> Dict[str, torch.Tensor]:
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            # Each layer caches the key and value tensors for its self-attention and cross-attention.
            # Hence the `layer_cache` tuple has 4 elements.
            assert len(layer_cache) == 4
            for tensor_index, tensor in enumerate(layer_cache):
                key = f"decoder_cache_{layer_index}_{tensor_index}"
                cache_dict[key] = tensor
        return cache_dict

    def _dict_to_decoder_cache(self, cache_dict: Dict[str, torch.Tensor]) -> DecoderCacheType:
        decoder_cache = []
        for layer_index in range(len(self.bart.model.decoder.layers)):
            base_key = f"decoder_cache_{layer_index}_"
            layer_cache = (
                cache_dict[base_key + "0"],
                cache_dict[base_key + "1"],
                cache_dict[base_key + "2"],
                cache_dict[base_key + "3"],
            )
            decoder_cache.append(layer_cache)
        assert decoder_cache
        return tuple(decoder_cache)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)`
            and an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        decoder_cache = None
        decoder_cache_dict = {
            k: state[k].contiguous()
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        encoder_outputs = (state["encoder_states"],) if "encoder_states" in state else None
        outputs = self.bart(
            input_ids=state["input_ids"] if encoder_outputs is None else None,
            attention_mask=state["input_mask"],
            encoder_outputs=encoder_outputs,
            decoder_input_ids=last_predictions,
            past_key_values=decoder_cache,
            use_cache=True,
            return_dict=True,
        )

        logits = outputs.logits[:, -1, :]
        log_probabilities = F.log_softmax(logits, dim=-1)

        decoder_cache = outputs.past_key_values
        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        state["encoder_states"] = outputs.encoder_last_hidden_state

        return log_probabilities, state

    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor
            should have shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of
            lists of tokens.
        """
        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()},
                self.vocab,
            )
        output_dict["predicted_tokens"] = predicted_tokens  # type: ignore
        output_dict["predicted_text"] = self._indexer._tokenizer.batch_decode(
            predictions.tolist(), skip_special_tokens=True
        )

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics
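In this version the beam search is injected lazily, so beam size and decoding length are configured on the `BeamSearch` object rather than on the model; the deprecated `beam_size`/`max_decoding_steps` kwargs shown in `__init__` still work but emit a `DeprecationWarning`. A brief usage sketch, assuming the "facebook/bart-base" weights are available:

from allennlp.common.lazy import Lazy
from allennlp.data import Vocabulary
from allennlp.nn.beam_search import BeamSearch

# Equivalent of the old beam_size=4, max_decoding_steps=140 arguments, without the deprecation warning.
model = Bart(
    model_name="facebook/bart-base",
    vocab=Vocabulary(),
    beam_search=Lazy(BeamSearch, beam_size=4, max_steps=140),
)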