def setUp(self):
    """Build the ``BLEU`` metric instance used by the tests.

    The metric is configured with two n-gram weights of 0.5 each and with
    index 0 excluded.
    """
    super().setUp()
    ngram_weights = (0.5, 0.5)
    excluded = {0}
    self.metric = BLEU(ngram_weights=ngram_weights, exclude_indices=excluded)
class VAE(Model):
    """
    A :class:`Model` implementing a simple sentence VAE, as first described in
    https://arxiv.org/pdf/1511.06349.pdf (Bowman et al., 2015).

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary used to look up the start, end and padding indices and to
        turn predicted indices back into tokens.
    variational_encoder : ``VariationalEncoder``, required
        Encoder applied to the source tokens. Its output is read through the
        ``'prior'`` and ``'posterior'`` keys, and its ``latent_dim`` attribute
        gives the size of the latent vector.
    decoder : ``Decoder``, required
        Decoder called with the latent sample and the target tokens. Its output
        dictionary is expected to contain ``'loss'`` and ``'predictions'``.
    kl_weight : ``LossWeight``, required
        Weight applied to the KL term. It is read with ``get()`` and advanced
        with ``step()`` once per training forward pass.
    temperature : ``float``, optional (default=1.0)
        Passed to ``variational_encoder.reparametrize`` outside of training.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Applied to the model at the end of construction.
    """

    def __init__(
            self,
            vocab: Vocabulary,
            variational_encoder: VariationalEncoder,
            decoder: Decoder,
            kl_weight: LossWeight,
            temperature: float = 1.0,
            initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super().__init__(vocab)

        self._encoder = variational_encoder
        self._decoder = decoder

        self._latent_dim = variational_encoder.latent_dim
        self._encoder_output_dim = self._encoder.get_encoder_output_dim()

        # Special-token indices; all three are ignored by the BLEU metric.
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        padding_token = self.vocab._padding_token  # pylint: disable=protected-access
        self._pad_index = self.vocab.get_token_index(padding_token)

        special_indices = {self._pad_index, self._end_index, self._start_index}
        self._bleu = BLEU(exclude_indices=special_indices)
        self._kl_metric = Average()

        self.kl_weight = kl_weight
        self._temperature = temperature

        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Run the forward pass used at training, validation and test time.

        Returns a dictionary with the latent sample under ``'z'`` and
        ``'predictions'``. When ``target_tokens`` is given, the decoder output
        is merged in and ``'loss'`` is the decoder loss plus the weighted KL
        term.
        """
        distributions = self._encoder(source_tokens)
        prior = distributions['prior']
        posterior = distributions['posterior']

        # The weight is read before it is stepped, so the current batch uses
        # the pre-step value.
        current_kl_weight = self.kl_weight.get()
        if self.training:
            latent = posterior.rsample()
            self.kl_weight.step()
        else:
            latent = self._encoder.reparametrize(prior, posterior, self._temperature)

        # KL divergence summed over everything and averaged over the batch.
        kld = kl_divergence(posterior, prior).sum() / latent.size(0)
        self._kl_metric(kld)

        output_dict = {'z': latent, 'predictions': source_tokens['tokens']}
        if target_tokens:
            # Decode and combine the reconstruction loss with the KL term.
            output_dict.update(self._decoder(latent, target_tokens))
            reconstruction_loss = output_dict['loss']
            output_dict['loss'] = reconstruction_loss + kld * current_kl_weight

            if not self.training:
                self._bleu(output_dict["predictions"], target_tokens["tokens"])

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU (outside training only), the KL weight and the mean KL."""
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics['klw'] = float(self.kl_weight.get())
        all_metrics['kl'] = float(self._kl_metric.get_metric(reset=reset))
        return all_metrics

    def generate(self, num_to_sample: int = 1):
        """Sample ``num_to_sample`` latents from a standard normal and decode them."""
        device = self._get_prediction_device()
        zeros = torch.zeros((num_to_sample, self._latent_dim))
        prior_mean = nn_util.move_to_device(zeros, device)
        prior = Normal(prior_mean, torch.ones_like(prior_mean))
        generated = self._decoder.generate(prior.sample())
        return self.decode(generated)

    @overrides  # simple_seq2seq's decode
    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        Reads ``output_dict["predictions"]``, keeps only the first candidate
        when a row holds several, cuts each sequence at the first end symbol,
        maps the remaining indices to tokens and stores the result under
        ``predicted_tokens``.
        """
        predictions = output_dict["predictions"]
        if not isinstance(predictions, numpy.ndarray):
            predictions = predictions.detach().cpu().numpy()

        all_predicted_tokens = []
        for row in predictions:
            # Beam search yields the top k sequences per instance; only the
            # best (first) one is kept.
            if row.ndim > 1:
                row = row[0]
            token_ids = list(row)
            # Drop the first end symbol and everything after it.
            if self._end_index in token_ids:
                token_ids = token_ids[:token_ids.index(self._end_index)]
            all_predicted_tokens.append(
                [self.vocab.get_token_from_index(token_id) for token_id in token_ids])

        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict
class BleuTest(AllenNlpTestCase):
    """Unit tests for the ``BLEU`` metric using unigram and bigram weights."""

    def setUp(self):
        super().setUp()
        self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

    def test_get_valid_tokens_mask(self):
        """Positions holding the excluded index 0 must be masked out."""
        token_ids = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]])
        mask = self.metric._get_valid_tokens_mask(token_ids).long().numpy()
        expected = np.array([[1, 1, 1, 0], [0, 1, 1, 0]])
        np.testing.assert_array_equal(mask, expected)

    def test_ngrams(self):
        """N-grams of each size are produced; an oversized n yields nothing."""
        token_ids = torch.tensor([1, 2, 3, 1, 2, 0])
        expected_by_size = [
            # Unigrams.
            (1, {(1,): 2, (2,): 2, (3,): 1}),
            # Bigrams.
            (2, {(1, 2): 2, (2, 3): 1, (3, 1): 1}),
            # Trigrams.
            (3, {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1}),
            # ngram size too big, no ngrams produced.
            (7, {}),
        ]
        for ngram_size, expected in expected_by_size:
            counts = Counter(self.metric._ngrams(token_ids, ngram_size))
            assert counts == expected

    def test_bleu_computed_correctly(self):
        self.metric.reset()

        # shape: (batch_size, max_sequence_length)
        predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
        # shape: (batch_size, max_gold_sequence_length)
        gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]])

        self.metric(predictions, gold_targets)

        assert self.metric._prediction_lengths == 6
        assert self.metric._reference_lengths == 5

        # Clipped unigram matches per sentence: none, one, two.
        unigram_matches = 0 + 1 + 2
        assert self.metric._precision_matches[1] == unigram_matches

        # Predicted unigrams per sentence.
        unigram_total = 1 + 2 + 3
        assert self.metric._precision_totals[1] == unigram_total

        # Clipped bigram matches per sentence: only the last one has a match.
        bigram_matches = 0 + 0 + 1
        assert self.metric._precision_matches[2] == bigram_matches

        # Predicted bigrams per sentence.
        bigram_total = 0 + 1 + 2
        assert self.metric._precision_totals[2] == bigram_total

        # Brevity penalty should be 1.0
        assert self.metric._get_brevity_penalty() == 1.0

        bleu = self.metric.get_metric(reset=True)["BLEU"]
        unigram_term = 0.5 * (math.log(3) - math.log(6))
        bigram_term = 0.5 * (math.log(1) - math.log(3))
        check = math.exp(unigram_term + bigram_term)
        np.testing.assert_approx_equal(bleu, check)

    def test_bleu_computed_with_zero_counts(self):
        """With nothing accumulated the metric reports a BLEU of zero."""
        self.metric.reset()
        assert self.metric.get_metric()["BLEU"] == 0
def __init__(self,
             vocab: Vocabulary,
             source_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             max_decoding_steps: int,
             attention: Attention,
             schema_path: str = None,
             missing_alignment_int: int = 0,
             indexfield_padding_index: int = -1,
             beam_size: int = None,
             target_namespace: str = "tokens",
             target_embedding_dim: int = None,
             scheduled_sampling_ratio: float = 0.,
             use_bleu: bool = True,
             emb_dropout: float = 0.0,
             dec_dropout: float = 0.0,
             attn_loss_lambda: float = 0.5,
             token_based_metric: Metric = None) -> None:
    """
    Set up the metrics, beam search, embedders, attention, decoder cell and
    output projection.

    ``schema_path`` switches on the two SQL-specific metrics when it is not
    ``None``. ``beam_size`` falls back to 1 and ``target_embedding_dim`` falls
    back to the source embedder's output dimension when they are not given.
    ``token_based_metric`` defaults to a ``TokenSequenceAccuracy``.
    ``missing_alignment_int``, ``indexfield_padding_index`` and
    ``attn_loss_lambda`` are only stored here.
    """
    super().__init__(vocab)

    self._target_namespace = target_namespace
    self._scheduled_sampling_ratio = scheduled_sampling_ratio
    self._indexfield_padding_index = indexfield_padding_index
    self._missing_alignment_int = missing_alignment_int

    # The start symbol is the decoder input at the first timestep; the end
    # symbol marks the end of a decoded sequence.
    namespace = self._target_namespace
    self._start_index = self.vocab.get_token_index(START_SYMBOL, namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, namespace)

    self._bleu = None
    if use_bleu:
        padding_token = self.vocab._padding_token  # pylint: disable=protected-access
        pad_index = self.vocab.get_token_index(padding_token, namespace)
        self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})

    self._token_based_metric = token_based_metric or TokenSequenceAccuracy()

    # The attention supervision CE loss is logged as a metric.
    self._attn_sup_loss = Average()

    self._sql_metrics = schema_path is not None
    if self._sql_metrics:
        # SQL specific metrics: match between the templates free of schema
        # constants, and match between the schema constants.
        self._schema_free_match = GlobalTemplAccuracy(schema_path=schema_path)
        self._kb_match = KnowledgeBaseConstsAccuracy(schema_path=schema_path)

    # At prediction time a beam search finds the most likely target sequence.
    if not beam_size:
        beam_size = 1
    self._max_decoding_steps = max_decoding_steps
    self._beam_search = BeamSearch(self._end_index,
                                   max_steps=max_decoding_steps,
                                   beam_size=beam_size)

    # Dense embedding of source vocab tokens.
    self._source_embedder = source_embedder

    self._emb_dropout = Dropout(p=emb_dropout)
    self._dec_dropout = Dropout(p=dec_dropout)
    self._attn_loss_lambda = attn_loss_lambda

    # Encodes the sequence of source embeddings into a sequence of hidden states.
    self._encoder = encoder

    num_classes = self.vocab.get_vocab_size(namespace)

    # Attention mechanism applied to the encoder output for each step.
    self._attention = attention
    # NOTE(review): normalisation is switched off on the passed-in module
    # itself — presumably so the raw scores can be supervised; confirm.
    self._attention._normalize = False

    # Dense embedding of vocab words in the target space.
    if not target_embedding_dim:
        target_embedding_dim = source_embedder.get_output_dim()
    self._target_embedder = Embedding(num_classes, target_embedding_dim)

    # The decoder hidden state is initialised with the final encoder state, so
    # both must have the same size.
    self._encoder_output_dim = self._encoder.get_output_dim()
    self._decoder_output_dim = self._encoder_output_dim

    # A weighted average over encoder outputs is concatenated to the previous
    # target embedding to form the decoder input at each time step.
    self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim

    # An LSTM cell produces the decoder hidden state at each time step.
    # TODO (pradeep): Do not hardcode decoder cell type.
    self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

    # Projects the decoder hidden state into the output vocabulary space to
    # score each target token at each time step.
    self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
class Bart(Model):
    """
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language
    Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here
    uses a language modeling head and thus can be used for text generation.

    # Parameters

    model_name : `str`, required
        Name of the pre-trained BART model to use. Available options can be found in
        `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
    vocab : `Vocabulary`, required
        Vocabulary containing source and target vocabularies.
    beam_search : `Lazy[BeamSearch]`, optional (default = `Lazy(BeamSearch)`)
        Used during inference to select the tokens of the decoded output sequence.
    indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
        Indexer used to convert decoded sequences of ids into sequences of tokens.
        When omitted, one is built from `model_name` with the `tokens` namespace.
    encoder : `Seq2SeqEncoder`, optional (default = `None`)
        Encoder to use in BART. By default, the original BART encoder is used.
    """

    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
        indexer: PretrainedTransformerIndexer = None,
        encoder: Seq2SeqEncoder = None,
        **kwargs,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

        config = self.bart.config
        self._start_id = config.bos_token_id  # CLS
        self._decoder_start_id = config.decoder_start_token_id or self._start_id
        self._end_id = config.eos_token_id  # SEP
        self._pad_id = config.pad_token_id  # PAD

        # At prediction time a beam search finds the best target sequence.
        # For backwards compatibility, `beam_size` and `max_decoding_steps` are
        # still accepted as kwargs: they are forwarded to the BeamSearch
        # constructor and a DeprecationWarning is issued for each.
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead."
        )
        deprecated_to_current = (
            ("beam_size", "beam_size"),
            ("max_decoding_steps", "max_steps"),
        )
        beam_search_extras = {}
        for old_name, new_name in deprecated_to_current:
            if old_name in kwargs:
                beam_search_extras[new_name] = kwargs[old_name]
                warnings.warn(deprecation_warning.format(old_name), DeprecationWarning)

        self._beam_search = beam_search.construct(
            end_index=self._end_id, vocab=self.vocab, **beam_search_extras
        )

        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace the bart encoder with the given encoder. The two embedding
        # layers of the original encoder are handed to the wrapper.
        if encoder is not None:
            assert (
                encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
            )
            original_encoder = self.bart.model.encoder
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                original_encoder.embed_tokens,
                original_encoder.embed_positions,
            )

    def forward(
        self, source_tokens: TextFieldTensors, target_tokens: TextFieldTensors = None
    ) -> Dict[str, torch.Tensor]:
        """
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key. If no target
            tokens are given, the source tokens are shifted to the right by 1.

        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape `(batch_size,
            max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions`
            of shape `(batch_size, max_decoding_steps)` and `log_probabilities` of shape `(batch_size,)`.
        """
        source = source_tokens["tokens"]
        input_ids, input_mask = source["token_ids"], source["mask"]

        # Without targets, the input shifted right by 1 serves as the target.
        # Bart already does this internally but does not use the shifted
        # input for loss calculation.
        if target_tokens is None:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]
        else:
            target = target_tokens["tokens"]
            target_ids, target_mask = target["token_ids"], target["mask"]

        outputs = {}
        if self.training:
            bart_outputs = self.bart(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
                return_dict=True,
            )
            outputs["decoder_logits"] = bart_outputs.logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks
            gold_ids = cast(torch.LongTensor, target_ids[:, 1:].contiguous())
            gold_mask = cast(torch.BoolTensor, target_mask[:, 1:].contiguous())
            outputs["loss"] = sequence_cross_entropy_with_logits(
                bart_outputs.logits,
                gold_ids,
                gold_mask,
                label_smoothing=0.1,
                average="token",
            )
            return outputs

        # Inference: every sequence starts from the decoder start id.
        initial_decoder_ids = torch.tensor(
            [[self._decoder_start_id]],
            dtype=input_ids.dtype,
            device=input_ids.device,
        ).repeat(input_ids.shape[0], 1)

        initial_state = {
            "input_ids": input_ids,
            "input_mask": input_mask,
        }
        beam_result = self._beam_search.search(
            initial_decoder_ids, initial_state, self.take_step
        )

        top_k_predictions = beam_result[0]
        top_k_log_probs = beam_result[1]

        # Keep, per instance, the beam with the highest log probability.
        best_beam = (
            top_k_log_probs.argmax(dim=-1)
            .view(-1, 1, 1)
            .expand(-1, -1, top_k_predictions.shape[-1])
        )
        predictions = top_k_predictions.gather(dim=1, index=best_beam).squeeze(dim=1)

        self._rouge(predictions, target_ids)
        self._bleu(predictions, target_ids)

        outputs["predictions"] = predictions
        outputs["log_probabilities"] = (
            top_k_log_probs.gather(dim=-1, index=best_beam[..., 0]).squeeze(dim=-1)
        )

        self.make_output_human_readable(outputs)

        return outputs

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache: DecoderCacheType) -> Dict[str, torch.Tensor]:
        """Flatten the per-layer cache into `decoder_cache_{layer}_{slot}` entries."""
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            # Each layer caches the key and value tensors for its self-attention
            # and cross-attention, hence 4 tensors per layer.
            assert len(layer_cache) == 4
            cache_dict.update(
                {
                    f"decoder_cache_{layer_index}_{tensor_index}": tensor
                    for tensor_index, tensor in enumerate(layer_cache)
                }
            )
        return cache_dict

    def _dict_to_decoder_cache(self, cache_dict: Dict[str, torch.Tensor]) -> DecoderCacheType:
        """Rebuild the per-layer cache tuples from the flat state entries."""
        num_layers = len(self.bart.model.decoder.layers)
        decoder_cache = [
            tuple(
                cache_dict[f"decoder_cache_{layer_index}_{tensor_index}"]
                for tensor_index in range(4)
            )
            for layer_index in range(num_layers)
        ]
        assert decoder_cache
        return tuple(decoder_cache)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)` and
            an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Every state entry other than the fixed ones is a cached decoder tensor.
        non_cache_keys = {"input_ids", "input_mask", "encoder_states"}
        cached_tensors = {
            key: state[key].contiguous() for key in state if key not in non_cache_keys
        }
        past_key_values = None
        if cached_tensors:
            past_key_values = self._dict_to_decoder_cache(cached_tensors)

        # After the first step the encoder states are reused, not recomputed.
        if "encoder_states" in state:
            encoder_outputs = (state["encoder_states"],)
        else:
            encoder_outputs = None

        outputs = self.bart(
            input_ids=state["input_ids"] if encoder_outputs is None else None,
            attention_mask=state["input_mask"],
            encoder_outputs=encoder_outputs,
            decoder_input_ids=last_predictions,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )

        last_step_logits = outputs.logits[:, -1, :]
        log_probabilities = F.log_softmax(last_step_logits, dim=-1)

        new_cache = outputs.past_key_values
        if new_cache is not None:
            state.update(self._decoder_cache_to_dict(new_cache))

        state["encoder_states"] = outputs.encoder_last_hidden_state

        return log_probabilities, state

    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """

        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor should have
            shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of lists of
            tokens, and a `predicted_text` key holding the tokenizer's decoding of each row.

        """
        predictions = output_dict["predictions"]
        output_dict["predicted_tokens"] = [  # type: ignore
            self._indexer.indices_to_tokens(
                {"token_ids": predictions[row].tolist()},
                self.vocab,
            )
            for row in range(predictions.shape[0])
        ]
        output_dict["predicted_text"] = self._indexer._tokenizer.batch_decode(
            predictions.tolist(), skip_special_tokens=True
        )

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ROUGE and BLEU outside of training; nothing while training."""
        metrics: Dict[str, float] = {}
        if self.training:
            return metrics
        metrics.update(self._rouge.get_metric(reset=reset))
        metrics.update(self._bleu.get_metric(reset=reset))
        return metrics

    default_predictor = "seq2seq"
class MaskedCopyNet(Model):
    """
    LSTM encoder-decoder with optional attention over the encoder outputs and an
    optional second attention over a separately embedded "mask" sequence.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary holding the target namespace.
    embedder : ``TextFieldEmbedder``, required
        Embeds the source tokens and, during decoding, the previously predicted
        target tokens.
    encoder : ``Seq2SeqEncoder``, required
        Encodes the embedded source sequence.
    max_decoding_steps : ``int``, required
        Maximum number of decoding steps.
    attention : ``Attention``, optional (default = None)
        Attention over the encoder outputs.
    mask_embedder : ``TextFieldEmbedder``, optional (default = None)
        Embeds ``mask_tokens``; the embeddings are what ``mask_attention`` attends over.
    mask_attention : ``Attention``, optional (default = None)
        Attention over the mask embeddings. It only takes effect when
        ``mask_embedder`` is given as well.
    beam_size : ``int``, optional (default = None)
        Beam width; falls back to 1 when not given.
    target_namespace : ``str``, optional (default = "tokens")
        Vocabulary namespace of the target tokens.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        During training, the probability per timestep of feeding the previous
        prediction instead of the gold token.
    use_bleu : ``bool``, optional (default = True)
        Whether BLEU is tracked outside of training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 mask_embedder: TextFieldEmbedder = None,
                 mask_attention: Attention = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._embedder = embedder
        self._mask_embedder = mask_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention
        self._mask_attention = mask_attention

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = self._embedder.get_output_dim()

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # The mask-attended vector is appended to the decoder input only when both
        # the mask attention and the mask embedder are present (the same condition
        # `_prepare_output_projections` uses). Checking only the attention here
        # would dereference a missing embedder.
        if self._mask_attention and self._mask_embedder:
            self._decoder_input_dim += self._mask_embedder.get_output_dim()

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Beam-search step function: return the log probabilities of the next
        token for every element of the group, plus the updated decoder state.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None,
                mask_tokens: Dict[str, torch.LongTensor] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        """
        Encode the source, compute the loss when targets are given, and run
        beam search when the model is not in training mode.

        Extra keyword arguments are accepted and discarded. ``mask_tokens`` is
        mandatory whenever the model was built with a ``mask_embedder``.
        BLEU is updated from the best beam when targets are available and the
        metric is enabled.
        """
        del kwargs
        assert mask_tokens is not None or self._mask_embedder is None, \
            'You must pass `mask_tokens` when `mask_embedder` is not None'
        state = self.encode(source_tokens, mask_tokens)

        if target_tokens:
            state = self.init_decoder_state(state)
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            # The decoder state is re-initialised so that beam search starts
            # from the encoder's final state rather than from the state left
            # behind by the loss computation.
            state = self.init_decoder_state(state)
            predictions = self.beam_search(state)
            output_dict.update(predictions)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Convert every beam of every instance in ``output_dict["predictions"]``
        into tokens, cutting each sequence at its first end symbol, and store
        the nested lists under ``predicted_tokens``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()

        all_predicted_tokens = []
        for beams in predicted_indices:
            curr_predictions = []
            for beam in beams:
                token_ids = list(beam)
                # Collect indices till the first end_symbol
                if self._end_index in token_ids:
                    token_ids = token_ids[:token_ids.index(self._end_index)]
                predicted_tokens = [
                    self.vocab.get_token_from_index(token_id, namespace=self._target_namespace)
                    for token_id in token_ids
                ]
                curr_predictions.append(predicted_tokens)
            all_predicted_tokens.append(curr_predictions)

        output_dict["predicted_tokens"] = all_predicted_tokens  # [batch_size, k, num_decoding_steps]
        return output_dict

    def encode(self,
               source_tokens: Dict[str, torch.Tensor],
               mask_tokens: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Embed and encode the source tokens. When mask tokens and a mask
        embedder are both available, also embed the mask tokens.

        Returns a state dictionary with ``source_mask`` and ``encoder_outputs``
        and, in the mask case, ``mask_source_mask`` and ``mask_encoder_outputs``.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)

        state = {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs
        }

        if mask_tokens is not None and self._mask_embedder is not None:
            # The mask sequence is only embedded, not passed through the
            # encoder; the embeddings are stored under the "encoder outputs" key.
            embedded_mask = self._mask_embedder(mask_tokens)
            masker_mask = util.get_text_field_mask(mask_tokens)
            state.update(
                {
                    "mask_source_mask": masker_mask,
                    "mask_encoder_outputs": embedded_mask
                }
            )

        return state

    def init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Set ``decoder_hidden`` to the final encoder state and ``decoder_context``
        to zeros, in place, and return the same ``state`` dictionary.
        """
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Step the decoder greedily and return the per-step argmax
        ``predictions``; when targets are given, also return the ``loss``.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: feed the model's own previous prediction.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # No gold tokens available, so the previous prediction is the
                # only possible input.
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # Teacher forcing with the gold token.
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run beam search from the start symbol and return the top sequences
        under ``predictions`` and their scores under ``class_log_probabilities``.
        """
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Advance the decoder by one step: embed the previous predictions,
        concatenate the attended vectors that are enabled, run the LSTM cell
        and project its hidden state onto the target vocabulary.

        ``state`` is updated in place with the new hidden and context vectors.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # The source embedder is reused for target tokens; the predictions are
        # passed to it keyed by the target namespace.
        # shape: (group_size, target_embedding_dim)
        embedded_input = self._embedder({self._target_namespace: last_predictions})

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        if self._mask_attention and self._mask_embedder:
            mask_encoder_outputs = state["mask_encoder_outputs"]
            mask_source_mask = state["mask_source_mask"]
            mask_attended_input = self._prepare_mask_attended_input(
                decoder_hidden, mask_encoder_outputs, mask_source_mask
            )
            decoder_input = torch.cat((decoder_input, mask_attended_input), -1)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input,
            (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Return the attention-weighted sum of the encoder outputs."""
        encoder_outputs_mask = encoder_outputs_mask.float()
        input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)
        return attended_input

    def _prepare_mask_attended_input(self,
                                     decoder_hidden_state: torch.LongTensor = None,
                                     mask_encoder_outputs: torch.LongTensor = None,
                                     mask_encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Return the mask-attention-weighted sum of the mask embeddings."""
        encoder_outputs_mask = mask_encoder_outputs_mask.float()
        input_weights = self._mask_attention(decoder_hidden_state, mask_encoder_outputs, encoder_outputs_mask)
        attended_input = util.weighted_sum(mask_encoder_outputs, input_weights)
        return attended_input

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Sequence cross entropy between ``logits`` and the targets with their
        first position dropped; the first target position is never predicted.
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU when it is enabled and the model is not training."""
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
class AttnSupSeq2Seq(Model):
    """
    Adaptation of the ``SimpleSeq2Seq`` class in allennlp_models, with an auxiliary
    attention-supervision loss: the decoder's attention logits over the source
    positions are trained with cross entropy against a gold alignment sequence.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same
        namespace (`tokens`) or the target tokens can have a different namespace, in which
        case it needs to be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    attention : ``Attention``, required
        Computes the similarity between the decoder hidden state and the encoder outputs
        at each decoding step. Its ``_normalize`` flag is switched off by this model so that
        it returns raw logits; normalisation is done in ``_prepare_attended_input``.
    schema_path : ``str``, optional (default = None)
        If given, the SQL specific metrics (``GlobalTemplAccuracy`` and
        ``KnowledgeBaseConstsAccuracy``) are created with this path and reported.
    missing_alignment_int : ``int``, optional (default = 0)
        Value in the alignment sequence marking a decoding step that has no alignment.
        Such steps are masked out of the attention-supervision loss.
    indexfield_padding_index : ``int``, optional (default = -1)
        Value used to pad the alignment sequence. Padded steps are masked out of the
        attention-supervision loss.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, a beam of size 1 is used.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    emb_dropout : ``float``, optional (default = 0.0)
        Dropout probability applied to the encoder outputs in ``_encode``.
    dec_dropout : ``float``, optional (default = 0.0)
        Dropout probability applied to the decoder input and to the decoder hidden state
        before the output projection.
    attn_loss_lambda : ``float``, optional (default = 0.5)
        Weight of the attention-supervision loss in the total loss.
    token_based_metric : ``Metric``, optional (default = None)
        Metric called with the predicted and gold token strings during validation.
        ``TokenSequenceAccuracy`` is used when none is given.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention,
                 schema_path: str = None,
                 missing_alignment_int: int = 0,
                 indexfield_padding_index: int = -1,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.0,
                 dec_dropout: float = 0.0,
                 attn_loss_lambda: float = 0.5,
                 token_based_metric: Metric = None) -> None:
        super(AttnSupSeq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._indexfield_padding_index = indexfield_padding_index
        self._missing_alignment_int = missing_alignment_int

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(
                self.vocab._padding_token,  # pylint: disable=protected-access
                self._target_namespace)
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None

        if token_based_metric:
            self._token_based_metric = token_based_metric
        else:
            self._token_based_metric = TokenSequenceAccuracy()

        # Log the attention supervision CE loss as a metric.
        self._attn_sup_loss = Average()

        self._sql_metrics = schema_path is not None
        if self._sql_metrics:
            # SQL specific metrics: match between the templates free of schema constants,
            # and match between the schema constants.
            self._schema_free_match = GlobalTemplAccuracy(
                schema_path=schema_path)
            self._kb_match = KnowledgeBaseConstsAccuracy(
                schema_path=schema_path)

        # At prediction time, we use a beam search to find the most likely sequence of
        # target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)
        self._attn_loss_lambda = attn_loss_lambda

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step. Normalisation is
        # disabled because the raw logits are needed for the attention-supervision loss.
        self._attention = attention
        self._attention._normalize = False  # pylint: disable=protected-access

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize
        # the hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        # A weighted average over encoder outputs will be concatenated to the previous target
        # embedding to form the input to the decoder at each time step.
        self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

    def take_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # The attention logits are not needed by the beam search.
        # shape (output_projections): (group_size, num_classes)
        _, output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward_on_instances(
            self, instances: List[Instance]) -> List[Dict[str, numpy.ndarray]]:
        """
        Takes a list of :class:`~allennlp.data.instance.Instance`s, converts that text into
        arrays using this model's :class:`Vocabulary`, passes those arrays through
        :func:`self.forward()` and :func:`self.decode()` and returns the result. Before
        returning the result, we convert any ``torch.Tensors`` into numpy arrays and separate
        the batched output into a list of individual dicts per instance.

        In addition, for every input field that exposes a ``tokens`` attribute, the tokens are
        copied into the per-instance output under the field's name.

        Parameters
        ----------
        instances : List[Instance], required
            The instances to run the model on.

        Returns
        -------
        A list of the models output for each instance.
        """
        batch_size = len(instances)
        with torch.no_grad():
            cuda_device = self._get_prediction_device()
            dataset = Batch(instances)
            dataset.index_instances(self.vocab)
            model_input = util.move_to_device(dataset.as_tensor_dict(),
                                              cuda_device)
            outputs = self.decode(self(**model_input))

            instance_separated_output: List[Dict[str, numpy.ndarray]] = [
                {} for _ in dataset.instances
            ]
            for name, output in list(outputs.items()):
                if isinstance(output, torch.Tensor):
                    # NOTE(markn): This is a hack because 0-dim pytorch tensors are not iterable.
                    # This occurs with batch size 1, because we still want to include the loss
                    # in that case.
                    if output.dim() == 0:
                        output = output.unsqueeze(0)

                    if output.size(0) != batch_size:
                        self._maybe_warn_for_unseparable_batches(name)
                        continue
                    output = output.detach().cpu().numpy()
                elif len(output) != batch_size:
                    self._maybe_warn_for_unseparable_batches(name)
                    continue
                for instance_output, batch_element in zip(
                        instance_separated_output, output):
                    instance_output[name] = batch_element

            # Echo the input tokens back next to the predictions. Fields that do not carry
            # tokens are skipped.
            for instance_output, instance_input in zip(
                    instance_separated_output, instances):
                for field in instance_input.fields:
                    try:
                        instance_output[field] = instance_input.fields[field].tokens
                    except AttributeError:
                        continue
            return instance_separated_output

    @overrides
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            alignment_sequence: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        alignment_sequence : ``torch.Tensor``, optional (default = None)
           For every decoding step, the index of the aligned source position. A trailing
           singleton dimension is squeezed away. When omitted, only the token level loss is
           computed.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            if alignment_sequence is not None:
                # Remove the trailing dimension (from ListField[ListField[IndexField]]).
                alignment_sequence = alignment_sequence.squeeze(-1)
            # The `_forward_loop` decodes the input sequence and computes the loss during
            # training and validation.
            output_dict = self._forward_loop(state, target_tokens,
                                             alignment_sequence)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                if self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens["tokens"])

                predicted_tokens = self.decode(output_dict)["predicted_tokens"]
                target_tokens_str = self.decode_target_tokens(target_tokens)
                if self._token_based_metric:
                    self._token_based_metric(predicted_tokens,
                                             target_tokens_str)
                if self._sql_metrics:
                    self._kb_match(predicted_tokens, target_tokens_str)
                    self._schema_free_match(predicted_tokens,
                                            target_tokens_str)

        # In case of attention coverage mechanism, reset the coverage vector after every
        # batch. Attention modules without a coverage vector simply do not have the hook.
        reset_coverage = getattr(self._attention, "reset_coverage_vector", None)
        if callable(reset_coverage):
            reset_coverage()

        return output_dict

    def decode_target_tokens(self, target_tokens):
        """
        Convert the gold target indices to token strings.

        For every sequence, everything from the first end symbol onwards is dropped, then
        everything up to and including the first start symbol is dropped.

        Returns
        -------
        A list (one entry per batch element) of lists of token strings.
        """
        target_indices = target_tokens['tokens'].detach().cpu().numpy()
        target_tokens_output = []
        for row in target_indices:
            cur_target_indices = list(row)
            if self._end_index in cur_target_indices:
                end_position = cur_target_indices.index(self._end_index)
                cur_target_indices = cur_target_indices[:end_position]
            if self._start_index in cur_target_indices:
                start_position = cur_target_indices.index(self._start_index)
                cur_target_indices = cur_target_indices[start_position + 1:]
            target_tokens_output.append([
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace)
                for x in cur_target_indices
            ])
        return target_tokens_output

    @overrides
    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source tokens; returns the source mask and encoder outputs."""
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        # Note: the "embedding" dropout is applied to the encoder outputs.
        encoder_outputs = self._emb_dropout(encoder_outputs)
        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Add the initial decoder hidden state and context to ``state`` and return it."""
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            alignment_sequence: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        When ``target_tokens`` is given, the returned dictionary contains a ``loss`` entry:
        the token level cross entropy, plus ``attn_loss_lambda`` times the
        attention-supervision loss if ``alignment_sequence`` is given as well.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        step_attn_weights: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: feed the model its own previous prediction.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # No gold tokens available: greedy decoding.
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape (input_weights): (batch_size, max_input_sequence_length)
            # shape (output_projections): (batch_size, num_classes)
            input_weights, output_projections, state = self._prepare_output_projections(
                input_choices, state)
            step_attn_weights.append(input_weights.unsqueeze(1))

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        # The attention of the last decoding step is dropped.
        # NOTE(review): presumably the alignment sequence has no entry for the final
        # (end symbol) step - confirm against the dataset reader.
        # shape: (batch_size, num_decoding_steps - 1, max_input_sequence_length)
        attention_input_weights = torch.cat(step_attn_weights[:-1], 1)

        output_dict = {
            "predictions": predictions,
            'attention_input_weights': attention_input_weights
        }

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)

            if alignment_sequence is not None:
                alignment_mask = self._get_alignment_mask(alignment_sequence)
                attn_sup_loss = self._get_attn_sup_loss(
                    attention_input_weights, alignment_mask,
                    alignment_sequence, self._indexfield_padding_index)
                self._attn_sup_loss(attn_sup_loss.detach().cpu().item())
                loss = loss + self._attn_loss_lambda * attn_sup_loss

            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`. Returns the (unnormalised) attention
        logits over the source positions, the output projections and the updated state.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # shape (attended_input): (group_size, encoder_output_dim)
        # shape (input_weights): (group_size, max_input_sequence_length)
        attended_input, input_weights = self._prepare_attended_input(
            decoder_hidden, encoder_outputs, source_mask)

        # shape: (group_size, decoder_output_dim + target_embedding_dim)
        decoder_input = torch.cat((attended_input, embedded_input), -1)
        decoder_input = self._dec_dropout(decoder_input)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(
            self._dec_dropout(decoder_hidden))

        return input_weights, output_projections, state

    def _prepare_attended_input(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply attention over encoder outputs and decoder state.

        Returns the attended encoder summary together with the *unnormalised* attention
        logits (not the softmax weights), since the logits are what the
        attention-supervision loss consumes.
        """
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_logits = self._attention(decoder_hidden_state, encoder_outputs,
                                       encoder_outputs_mask)

        # The attention mechanism returns the logits that are necessary for the attention
        # supervision loss, so we normalize them here.
        input_weights = masked_softmax(input_logits, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input, input_logits

    @staticmethod
    def _get_attn_sup_loss(attn_weights: torch.Tensor,
                           alignment_mask: torch.Tensor,
                           alignment_sequence: torch.Tensor,
                           padding_index: int = -1) -> torch.Tensor:
        """
        Compute the attention supervision CE loss.

        For each decoding step, the index stored in ``alignment_sequence`` is used as the
        gold class over the source positions scored by ``attn_weights``.

        Parameters
        ----------
        attn_weights : ``torch.Tensor``
            Attention logits, shape ``(batch_size, num_steps, max_input_sequence_length)``.
        alignment_mask : ``torch.Tensor``
            Mask over the decoding steps that take part in the loss.
        alignment_sequence : ``torch.Tensor``
            Aligned source index for each decoding step. It is *not* modified.
        padding_index : ``int``, optional (default = -1)
            Padding value in ``alignment_sequence``. It is replaced by 0 so that it is a
            valid class index; these steps are expected to be masked out by
            ``alignment_mask``.
        """
        # shape: (batch_size, max_decoding_steps, max_input_seq_length)
        attn_weights = attn_weights.float()
        # Build a new tensor instead of overwriting the caller's batch data in place.
        alignment_targets = alignment_sequence.masked_fill(
            alignment_sequence == padding_index, 0)
        # For each attn_weights[batch_index, step_index, :] we choose the index given by
        # alignment_targets[batch_index, step_index].
        return util.sequence_cross_entropy_with_logits(attn_weights,
                                                       alignment_targets,
                                                       alignment_mask)

    def _get_alignment_mask(self, alignment_sequence):
        """
        Mask of the decoding steps that carry a usable alignment: steps that are neither
        padding (``indexfield_padding_index``) nor marked as unaligned
        (``missing_alignment_int``). The mask has the same shape as ``alignment_sequence``.
        """
        pad_mask = alignment_sequence != self._indexfield_padding_index
        missing_mask = alignment_sequence != self._missing_alignment_int
        return pad_mask & missing_mask

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the accumulated metrics. BLEU, the token based metric and the SQL metrics are
        only reported outside of training; the attention-supervision loss is always reported.
        """
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._bleu:
                all_metrics.update(self._bleu.get_metric(reset=reset))
            all_metrics.update(
                self._token_based_metric.get_metric(reset=reset))
            if self._sql_metrics:
                all_metrics.update(self._kb_match.get_metric(reset=reset))
                all_metrics.update(
                    self._schema_free_match.get_metric(reset=reset))
        all_metrics['attn_sup_loss'] = self._attn_sup_loss.get_metric(
            reset=reset)
        return all_metrics
class MyTransformer(Model):
    """
    Sequence-to-sequence model built on ``torch.nn.Transformer``.

    During training the decoder is run once over the (teacher forced) target prefix; at
    evaluation time the target is produced greedily, one token per step.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing the source and target vocabularies.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for the source tokens.
    transformer : ``Dict``, required
        Keyword arguments forwarded to ``Transformer``. The keys ``d_model`` and
        ``dropout`` are also read to build the positional encoding.
    max_decoding_steps : ``int``, required
        Maximum number of greedy decoding steps.
    target_namespace : ``str``, required
        Vocabulary namespace of the target tokens.
    target_embedder : ``TextFieldEmbedder``, optional (default = None)
        Embedder for the target tokens. The source embedder is shared when omitted.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric is accumulated whenever targets are given.
    """

    def __init__(
            self,
            vocab: Vocabulary,
            source_embedder: TextFieldEmbedder,
            transformer: Dict,
            max_decoding_steps: int,
            target_namespace: str,
            target_embedder: TextFieldEmbedder = None,
            use_bleu: bool = True,
    ) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(
            self.vocab._padding_token,  # pylint: disable=protected-access
            self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None
        self._seq_acc = SequenceAccuracy()

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,
                                              transformer["dropout"])

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)

        if target_embedder is None:
            self._target_embedder = self._source_embedder
        else:
            self._target_embedder = target_embedder

        self._output_projection_layer = Linear(self._ndim, num_classes)

    def _get_mask(self, meta_data):
        """
        Build an additive mask over the target vocabulary, one row per batch element.

        Every target token whose string contains ``'position'`` and that is not listed in
        the element's ``meta_data[i]['avail_pos']`` gets ``-inf``; everything else is 0.

        Returns
        -------
        A float tensor of shape ``(1, batch_size, target_vocab_size)``.
        """
        mask = torch.zeros(1, len(meta_data),
                           self.vocab.get_vocab_size(
                               self._target_namespace)).float()
        # The set of position tokens does not depend on the batch element, so scan the
        # vocabulary only once.
        token_to_index = self.vocab._token_to_index[  # pylint: disable=protected-access
            self._target_namespace]
        position_tokens = [(token, index)
                           for token, index in token_to_index.items()
                           if 'position' in token]
        for bidx, md in enumerate(meta_data):
            for token, index in position_tokens:
                if token not in md['avail_pos']:
                    mask[:, bidx, index] = float('-inf')
        return mask

    def generate_square_subsequent_mask(self, sz):
        """
        Return the ``(sz, sz)`` causal attention mask: 0.0 where position ``i`` may attend to
        position ``j`` (``j <= i``) and ``-inf`` everywhere else.
        """
        # Lower triangular (including the diagonal) marks the allowed positions.
        allowed = torch.triu(torch.ones(sz, sz)).transpose(0, 1)
        return torch.zeros(sz, sz).masked_fill(allowed == 0, float('-inf'))

    @overrides
    def forward(
            self,
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            meta_data: Any = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Encode the source and either compute the training loss (training mode with targets)
        or decode greedily (evaluation mode).

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
            Indexed source tokens; passed to the source embedder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Indexed target tokens under the key ``"tokens"``. The first token is used as
            decoder input only; the remaining ones are the prediction targets.
        meta_data : ``Any``, optional (default = None)
            Per-instance dictionaries with an ``avail_pos`` entry, used to mask target
            vocabulary entries (see ``_get_mask``).

        Returns
        -------
        Dict[str, torch.Tensor]
            ``loss`` (0.0 in evaluation mode), ``predictions`` of shape
            ``(batch_size, num_steps)`` and ``class_probabilities``.
        """
        src, src_key_padding_mask = self._encode(self._source_embedder,
                                                 source_tokens)
        memory = self._transformer.encoder(
            src, src_key_padding_mask=src_key_padding_mask)

        if meta_data is not None:
            target_vocab_mask = self._get_mask(meta_data)
            target_vocab_mask = target_vocab_mask.to(memory.device)
        else:
            target_vocab_mask = None

        output_dict = {}
        targets = None
        if target_tokens:
            # Drop the start symbol: these are the tokens the decoder has to predict.
            targets = target_tokens["tokens"][:, 1:]
            target_mask = (util.get_text_field_mask({"tokens": targets}) == 1)
            assert targets.size(1) <= self._max_decoding_steps, \
                "target sequence is longer than max_decoding_steps"

        if self.training and target_tokens:
            # Teacher forcing: the decoder input is the target without its last token.
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": target_tokens["tokens"][:, :-1]})
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)
            logits = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                logits += target_vocab_mask
            class_probabilities = F.softmax(logits.detach(), dim=-1)
            # shape: (num_decoding_steps, batch_size)
            _, predictions = torch.max(class_probabilities, -1)
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = logits.transpose(0, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
        else:
            assert self.training is False, \
                "target_tokens are required in training mode"
            output_dict["loss"] = torch.tensor(0.0, device=memory.device)
            if targets is not None:
                max_target_len = targets.size(1)
            else:
                max_target_len = None
            predictions, class_probabilities = self._decoder_step_by_step(
                memory,
                src_key_padding_mask,
                target_vocab_mask,
                max_target_len=max_target_len)

        # Both branches produce time-major tensors; switch to batch-major for the output.
        predictions = predictions.transpose(0, 1)
        output_dict["predictions"] = predictions
        output_dict["class_probabilities"] = class_probabilities.transpose(
            0, 1)

        if target_tokens:
            with torch.no_grad():
                best_predictions = output_dict["predictions"]
                if self._bleu:
                    self._bleu(best_predictions, targets)
                # Predictions and targets can differ in length, so right-pad all three
                # tensors with zeros to a common width before computing sequence accuracy.
                batch_size = targets.size(0)
                max_sz = max(best_predictions.size(1), targets.size(1),
                             target_mask.size(1))
                best_predictions_ = torch.zeros(batch_size, max_sz,
                                                device=memory.device)
                best_predictions_[:, :best_predictions.size(1)] = best_predictions
                targets_ = torch.zeros(batch_size, max_sz,
                                       device=memory.device)
                targets_[:, :targets.size(1)] = targets
                target_mask_ = torch.zeros(batch_size, max_sz,
                                           device=memory.device)
                target_mask_[:, :target_mask.size(1)] = target_mask
                self._seq_acc(best_predictions_.unsqueeze(1), targets_,
                              target_mask_)

        return output_dict

    @overrides
    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Trim every predicted sequence at its first end symbol, map the indices to token
        strings and store the result under ``predicted_tokens`` in ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            # shape: (batch_size, num_decoding_steps)
            predicted_indices = predicted_indices.detach().cpu().numpy()

        all_predicted_tokens = []
        for indices in predicted_indices:
            # If there are several candidates per batch element (e.g. from a beam search),
            # keep only the first one.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self, embedder: TextFieldEmbedder,
            tokens: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embed ``tokens``, scale by ``sqrt(d_model)``, switch to time-major layout and add the
        positional encoding.

        Returns
        -------
        The encoded input and the key padding mask (True at padded positions).
        """
        src = embedder(tokens) * math.sqrt(self._ndim)
        # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        src = src.transpose(0, 1)
        src = self.pos_encoder(src)
        mask = util.get_text_field_mask(tokens)
        mask = (mask == 0)
        return src, mask

    def _decoder_step_by_step(
            self,
            memory: torch.Tensor,
            memory_key_padding_mask: torch.Tensor,
            target_vocab_mask: torch.Tensor = None,
            max_target_len: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Greedy decoding: at every step the decoder is re-run on all tokens produced so far
        and the most probable next token is appended.

        Decoding stops early once every sequence in the batch has just emitted the end or
        padding symbol. If the model has a truthy ``target_limit_decode_steps`` attribute
        and ``max_target_len`` is given, the number of steps is capped by ``max_target_len``.

        Returns
        -------
        ``predictions`` of shape ``(num_steps, batch_size)`` and the class probabilities of
        the last decoder run, shape ``(num_steps, batch_size, num_classes)``.
        """
        batch_size = memory.size(1)
        if getattr(self, "target_limit_decode_steps",
                   False) and max_target_len is not None:
            num_decoding_steps = min(self._max_decoding_steps, max_target_len)
            print('decoding steps: ', num_decoding_steps)
        else:
            num_decoding_steps = self._max_decoding_steps

        last_predictions = memory.new_full(
            (batch_size, ), fill_value=self._start_index).long()

        step_predictions: List[torch.Tensor] = []
        # Decoder input buffer; column t holds the token fed at step t.
        all_predicts = memory.new_full((batch_size, num_decoding_steps),
                                       fill_value=0).long()
        for timestep in range(num_decoding_steps):
            all_predicts[:, timestep] = last_predictions
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": all_predicts[:, :timestep + 1]})
            tgt_mask = self.generate_square_subsequent_mask(timestep + 1).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
            output_projections = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                output_projections += target_vocab_mask

            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, -1)

            # Only the prediction for the newest position is kept.
            # shape (last_predictions): (batch_size,)
            last_predictions = predicted_classes[timestep, :]
            step_predictions.append(last_predictions)
            finished = ((last_predictions == self._end_index) |
                        (last_predictions == self._pad_index))
            if finished.all():
                break

        # shape: (num_decoding_steps, batch_size)
        predictions = torch.stack(step_predictions)
        return predictions, class_probabilities

    @staticmethod
    def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.FloatTensor) -> torch.Tensor:
        """
        Masked cross entropy between ``logits`` and ``targets``. Unlike the seq2seq
        convention of shifting inside the loss, the caller passes targets that are already
        aligned with the logits.
        """
        logits = logits.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask.contiguous()
        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU (if enabled) and the sequence accuracy under ``seq_acc``."""
        all_metrics: Dict[str, float] = {}
        if self._bleu:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics['seq_acc'] = self._seq_acc.get_metric(reset=reset)
        return all_metrics

    def load_state_dict(self, state_dict, strict=True):
        """
        Load ``state_dict`` after stripping a leading ``'module.'`` from every key that has
        one, so that checkpoints saved with that prefix can be loaded directly.
        """
        prefix = 'module.'
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        super(MyTransformer, self).load_state_dict(new_state_dict, strict)
class FactParaphraseSeq2Seq(Model):
    """
    Given facts and dialog acts, it generates the paraphrased message.

    TODO: add dialog & dialog acts history

    This implementation is based off the default SimpleSeq2Seq model,
    which takes a sequence, encodes it, and then uses the encoded
    representations to decode another sequence.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Vocabulary; target tokens are looked up under ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``
        Embeds the source (facts) tokens.
    source_encoder : ``Seq2SeqEncoder``
        Encodes the embedded source sequence.
    max_decoding_steps : ``int``
        Maximum number of steps taken by beam search / greedy decoding.
    dialog_acts_encoder : ``FeedForward``, optional
        Encodes the embedded dialog acts. Required when ``use_dialog_acts``
        is true.
    attention : ``Attention``, optional
        Attention applied over encoder outputs at every decoding step.
    attention_function : ``SimilarityFunction``, optional
        Wrapped in ``LegacyAttention``. Mutually exclusive with ``attention``.
    n_dialog_acts : ``int``, optional
        Number of rows of the dialog-act ``EmbeddingBag``. Required when
        ``use_dialog_acts`` is true.
    beam_size : ``int``, optional
        Beam width; a falsy value falls back to 1.
    target_namespace : ``str``
        Vocabulary namespace of the target tokens.
    target_embedding_dim : ``int``, optional
        Defaults to ``source_embedder.get_output_dim()``.
    scheduled_sampling_ratio : ``float``
        Probability, during training, of feeding the model's own previous
        prediction instead of the gold token.
    use_bleu : ``bool``
        Whether to track BLEU (padding, start and end indices excluded).
    use_dialog_acts : ``bool``
        Whether the encoder state is conditioned on dialog acts.
    regularizers : ``RegularizerApplicator``, optional
        Passed through to ``Model``.

    Raises
    ------
    ConfigurationError
        If both ``attention`` and ``attention_function`` are given, or if
        ``use_dialog_acts`` is true while ``dialog_acts_encoder`` or
        ``n_dialog_acts`` is missing.
    """

    def __init__(
            self,
            vocab: Vocabulary,
            source_embedder: TextFieldEmbedder,
            source_encoder: Seq2SeqEncoder,
            max_decoding_steps: int,
            dialog_acts_encoder: FeedForward = None,
            attention: Attention = None,
            attention_function: SimilarityFunction = None,
            n_dialog_acts: int = None,
            beam_size: int = None,
            target_namespace: str = "tokens",
            target_embedding_dim: int = None,
            scheduled_sampling_ratio: float = 0.0,
            use_bleu: bool = True,
            use_dialog_acts: bool = True,
            regularizers: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizers)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first
        # timestep of decoding, and end symbol as a way to indicate the end
        # of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(
                self.vocab._padding_token,  # pylint: disable=protected-access
                self._target_namespace)
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most
        # likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source (Facts) vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._source_encoder = source_encoder

        if use_dialog_acts:
            # Fail with an actionable message instead of an AttributeError on
            # ``None`` when the dialog-act components were not configured.
            if dialog_acts_encoder is None or n_dialog_acts is None:
                raise ConfigurationError(
                    "use_dialog_acts=True requires both dialog_acts_encoder "
                    "and n_dialog_acts to be specified.")
            # Dense embedding of dialog acts.
            da_embedding_dim = dialog_acts_encoder.get_input_dim()
            self._dialog_acts_embedder = EmbeddingBag(n_dialog_acts,
                                                      da_embedding_dim)
            # Encodes dialog acts
            self._dialog_acts_encoder = dialog_acts_encoder
        else:
            self._dialog_acts_embedder = None
            self._dialog_acts_encoder = None

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError(
                    "You can only specify an attention module or an "
                    "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = (target_embedding_dim
                                or source_embedder.get_output_dim())
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim
        # since we initialize the hidden state of the decoder with the final
        # hidden state of the encoder.
        self._encoder_output_dim = self._source_encoder.get_output_dim()
        if use_dialog_acts:
            # Projects [final encoder state ; dialog-act state] back down to
            # the encoder output size.
            self._merge_encoder = Sequential(
                Linear(
                    self._source_encoder.get_output_dim() +
                    self._dialog_acts_encoder.get_output_dim(),
                    self._encoder_output_dim,
                ))
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will
            # be concatenated to the previous target embedding to form the input
            # to the decoder at each time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

    def take_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Returns the log-probabilities over target classes for the next token
        together with the updated decoder state.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            dialog_acts: Optional[torch.Tensor] = None,
            sender: Optional[torch.Tensor] = None,
            metadata: Optional[Dict] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        When ``target_tokens`` is given the output contains ``loss`` (and the
        greedy ``predictions`` of the teacher-forced loop). Outside training
        mode, beam search is run as well and overwrites ``predictions`` and
        adds ``class_log_probabilities``. ``sender`` and ``metadata`` are
        accepted but not used by this method.
        """
        source_state, dialog_acts_state = self._encode(source_tokens,
                                                       dialog_acts)

        if target_tokens:
            state = self._init_decoder_state(source_state, dialog_acts_state)
            # The `_forward_loop` decodes the input sequence and
            # computes the loss during training and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            # Re-initialise: the loop above advanced the decoder state in place.
            state = self._init_decoder_state(source_state, dialog_acts_state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    @overrides
    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        Adds ``predicted_tokens``: for every batch element, the tokens of the
        best sequence truncated at the first end symbol.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence
            # in the batch but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor],
            dialog_acts: torch.Tensor = None
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Encode the source tokens and, if configured, the dialog acts."""
        # Encode source tokens
        source_state = self._encode_source_tokens(source_tokens)

        # Encode dialog acts
        if self._dialog_acts_encoder is not None:
            dialog_acts_state = self._encode_dialog_acts(dialog_acts)
        else:
            dialog_acts_state = None

        return (source_state, dialog_acts_state)

    def _encode_source_tokens(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source tokens; returns mask and encoder outputs."""
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._source_encoder(embedded_input, source_mask)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

    def _encode_dialog_acts(self, dialog_acts: torch.Tensor) -> torch.Tensor:
        """Embed the dialog acts with the ``EmbeddingBag`` and encode them."""
        # shape: (batch_size, dialog_acts_embeddings_size)
        embedded_dialog_acts = self._dialog_acts_embedder(dialog_acts)
        # shape: (batch_size, dim_encoder)
        dialog_acts_state = self._dialog_acts_encoder(embedded_dialog_acts)
        return dialog_acts_state

    def _init_decoder_state(
            self,
            source_state: Dict[str, torch.Tensor],
            dialog_acts_state: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Add ``decoder_hidden`` and ``decoder_context`` to ``source_state``.

        The dict is updated in place and returned.
        """
        batch_size = source_state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            source_state["encoder_outputs"],
            source_state["source_mask"],
            self._source_encoder.is_bidirectional(),
        )

        # Condition the source tokens state with dialog acts state
        if self._dialog_acts_encoder is not None:
            final_encoder_output = self._merge_encoder(
                torch.cat([final_encoder_output, dialog_acts_state], dim=1))

        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        source_state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        source_state["decoder_context"] = source_state[
            "encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        return source_state

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: during training, feed the model's own
                # previous prediction at a rate of _scheduled_sampling_ratio.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # No gold tokens available: feed the previous prediction.
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # Teacher forcing with the gold token.
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`. ``state`` is updated in
        place with the new ``decoder_hidden`` and ``decoder_context``.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(
                decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(
            self,
            decoder_hidden_state: torch.Tensor = None,
            encoder_outputs: torch.Tensor = None,
            encoder_outputs_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(decoder_hidden_state, encoder_outputs,
                                        encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(
            logits: torch.FloatTensor,
            targets: torch.LongTensor,
            target_mask: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Masked sequence cross-entropy of ``logits`` against ``targets``.

        The first target position (the start symbol) is dropped so that the
        logits produced at step ``t`` are scored against target ``t + 1``.
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU when it is enabled and the model is not in training mode."""
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
class LatentAignmentCTC(Model):
    """
    Sequence model trained with a CTC loss over a latent alignment.

    The embedded source is passed through an upsampling module and an
    encoder network; a linear layer then scores every target-vocabulary
    class (including a special blank token) at each resulting position.
    Outside training mode the per-position argmax alignment is collapsed
    into an output sequence by ``beta_inverse``.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Vocabulary; target tokens live under ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``
        Embeds the source tokens.
    upsample : ``torch.nn.Module``, optional
        Called as ``upsample(embedded_input, source_mask)`` and must return
        the upsampled input and its mask. Defaults to
        ``LinearUpsample(source_embedding_dim, s=3)``.
    net : ``Seq2SeqEncoder``, optional
        Encoder applied to the upsampled input. Defaults to a 4-layer
        ``StackedSelfAttentionEncoder``.
    target_namespace : ``str``
        Vocabulary namespace of the target tokens.
    target_embedding_dim : ``int``, optional
        Accepted for interface compatibility; the value is not used, the
        projection input size is always ``net.get_output_dim()``.
    use_bleu : ``bool``
        Whether to track BLEU (padding and blank indices excluded).
    """

    def __init__(
            self,
            vocab: Vocabulary,
            source_embedder: TextFieldEmbedder,
            upsample: torch.nn.Module = None,
            net: Seq2SeqEncoder = None,
            target_namespace: str = "target_tokens",
            target_embedding_dim: int = None,
            use_bleu: bool = True,
    ) -> None:
        super(LatentAignmentCTC, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._pad_index = self.vocab.get_token_index(
            self.vocab._padding_token,  # pylint: disable=protected-access
            self._target_namespace)
        self._blank_index = self.vocab.get_token_index(SPECIAL_BLANK_TOKEN,
                                                       self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={self._pad_index, self._blank_index})
        else:
            self._bleu = None

        self._source_embedder = source_embedder
        source_embedding_dim = source_embedder.get_output_dim()

        self._upsample = upsample or LinearUpsample(source_embedding_dim, s=3)
        self._net = net or StackedSelfAttentionEncoder(input_dim=source_embedding_dim,
                                                       hidden_dim=128,
                                                       projection_dim=128,
                                                       feedforward_hidden_dim=512,
                                                       num_layers=4,
                                                       num_attention_heads=4)

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # The projection input size is dictated by the encoder network, so
        # the ``target_embedding_dim`` argument is intentionally overridden.
        target_embedding_dim = self._net.get_output_dim()
        self._output_projection = torch.nn.Linear(target_embedding_dim, num_classes)

    @overrides
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Score an alignment for the source and optionally compute the CTC loss.

        The output always contains ``source_mask_upsampled``, ``net_output``
        and ``alignment_logits``; ``loss`` is added when ``target_tokens`` is
        given, and ``predictions`` (a CPU ``LongTensor``) when the model is
        not in training mode.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)

        # NOTE(review): exact output shapes depend on the upsample module —
        # presumably the sequence axis grows by the upsampling factor; confirm
        # against LinearUpsample.
        source_upsampled, source_mask_upsampled = self._upsample(embedded_input,
                                                                 source_mask)

        net_output = self._net(source_upsampled, source_mask_upsampled)

        output_dict = {"source_mask_upsampled": source_mask_upsampled,
                       "net_output": net_output}

        alignment_logits = self._output_projection(net_output)
        output_dict["alignment_logits"] = alignment_logits

        if target_tokens:
            # Compute loss.
            loss = self._get_loss(output_dict, target_tokens)
            output_dict["loss"] = loss

        if not self.training:
            # Greedy alignment: best class at every position.
            alignments = alignment_logits.detach().cpu().argmax(2)
            predictions = self.beta_inverse(alignments)
            output_dict["predictions"] = predictions

            if target_tokens and self._bleu:
                self._bleu(output_dict['predictions'], target_tokens["tokens"])

        return output_dict

    # TODO: naive Python loop; needs parallel (vectorized) processing.
    def beta_inverse(self, a: torch.Tensor):
        """
        Collapse alignments into output sequences (the CTC collapse).

        Consecutive repeats of the same alignment token are merged first and
        blanks are removed afterwards, so a blank between two identical
        tokens keeps both of them (``a _ a`` -> ``a a``) while ``a a`` -> ``a``.

        Parameters
        ----------
        a : ``torch.Tensor``
            Alignment indices of size (batch, sequence).

        Returns
        -------
        ``torch.LongTensor`` of size (batch, sequence); every row is
        right-padded with the padding index.
        """
        max_length = a.size(1)
        outputs = []
        for sequence in a.tolist():
            output = []
            # Compare against the previous *alignment* token (not the last
            # emitted one) so that a blank resets the repeat detection.
            previous = None
            for token in sequence:
                if token != previous and token != self._blank_index:
                    output.append(token)
                previous = token

            pad_list = [self._pad_index] * (max_length - len(output))
            outputs.append(output + pad_list)

        return torch.LongTensor(outputs)

    def _get_loss(self,
                  output_dict: Dict[str, torch.Tensor],
                  target_tokens: Dict[str, torch.LongTensor]) -> torch.Tensor:
        """Compute the CTC loss of the alignment logits against the targets."""
        targets = target_tokens["tokens"]
        target_mask = util.get_text_field_mask(target_tokens)

        # shape: (batch_size, input_length, target_size)
        alignment_logits = output_dict["alignment_logits"]
        # shape: (batch_size, input_length)
        source_mask_upsampled = output_dict["source_mask_upsampled"]

        return sequence_ctc_loss_with_logits(alignment_logits,
                                             source_mask_upsampled,
                                             targets,
                                             target_mask,
                                             self._blank_index)

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        Adds ``predicted_tokens`` (each sequence truncated at the first
        padding index) and removes the intermediate ``alignment_logits``,
        ``source_mask_upsampled`` and ``net_output`` entries.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)

            # remove pad
            if self._pad_index in indices:
                indices = indices[: indices.index(self._pad_index)]

            # lookup
            predicted_tokens = [
                self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)

        # provide "tokens" and "predicted_tokens" for output.
        output_dict["predicted_tokens"] = all_predicted_tokens
        del output_dict["alignment_logits"], output_dict['source_mask_upsampled'], output_dict['net_output']
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU when it is enabled and the model is not in training mode."""
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
class Bart(Model):
    """
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language
    Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here
    uses a language modeling head and thus can be used for text generation.
    """

    def __init__(
            self,
            model_name: str,
            vocab: Vocabulary,
            indexer: PretrainedTransformerIndexer = None,
            max_decoding_steps: int = 140,
            beam_size: int = 4,
            encoder: Seq2SeqEncoder = None,
    ):
        """
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
            `transformers.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
            When omitted, one is built from `model_name` with namespace `"tokens"`.
        max_decoding_steps : `int`, optional (default = `140`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. A falsy value falls back to 1.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to use in BART. By default, the original BART encoder is used.
            Its input and output dimensions must both equal the BART hidden size.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )

    @overrides
    def forward(
            self,
            source_tokens: TextFieldTensors,
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        """
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key. If no target
            tokens are given, the source tokens are shifted to the right by 1.

        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape `(batch_size,
            max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions`
            of shape `(batch_size, max_decoding_steps)`, `log_probabilities` of shape `(batch_size,)` and
            the human readable `predicted_tokens`.
        """
        inputs = source_tokens
        targets = target_tokens
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs[
            "tokens"]["mask"]

        outputs = {}

        # If no targets are provided, then shift input to right by 1. Bart already does this internally
        # but it does not use them for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets[
                "tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:
            decoder_logits = self.bart(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
            )[0]

            outputs["decoder_logits"] = decoder_logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks
            outputs["loss"] = sequence_cross_entropy_with_logits(
                decoder_logits,
                target_ids[:, 1:].contiguous(),
                target_mask[:, 1:].contiguous(),
                label_smoothing=0.1,
                average="token",
            )
        else:
            # Use decoder start id and start of sentence to start decoder
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id, self._start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "encoder_states": None,
            }
            beam_result = self._beam_search.search(initial_decoder_ids,
                                                   initial_state,
                                                   self.take_step)

            predictions = beam_result[0]
            # Select, per batch element, the beam with the highest log-probability.
            max_pred_indices = (beam_result[1].argmax(dim=-1).view(
                -1, 1, 1).expand(-1, -1, predictions.shape[-1]))
            predictions = predictions.gather(
                dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (beam_result[1].gather(
                dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1))

            self.make_output_human_readable(outputs)

        return outputs

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache):
        """
        Flatten the nested decoder cache into a flat dictionary.

        Keys are `(layer_index, attention_name, tensor_name)` tuples so the
        cached tensors can be stored in the beam search state.
        """
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            for attention_name, attention_cache in layer_cache.items():
                for tensor_name, cache_value in attention_cache.items():
                    key = (layer_index, attention_name, tensor_name)
                    cache_dict[key] = cache_value
        return cache_dict

    @staticmethod
    def _dict_to_decoder_cache(cache_dict):
        """Inverse of `_decoder_cache_to_dict`: rebuild the per-layer list of nested dicts."""
        decoder_cache = []
        for key, cache_value in cache_dict.items():
            # Split key and extract index and dict keys
            layer_idx, attention_name, tensor_name = key
            # Extend decoder_cache to fit layer_idx + 1 layers
            decoder_cache = decoder_cache + [
                {} for _ in range(layer_idx + 1 - len(decoder_cache))
            ]
            cache = decoder_cache[layer_idx]
            if attention_name not in cache:
                cache[attention_name] = {}
            assert tensor_name not in cache[attention_name]
            cache[attention_name][tensor_name] = cache_value
        return decoder_cache

    def take_step(self, last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  step: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)` and
            an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Only the last predictions are needed for the decoder, but we need to pad the decoder ids
        # to not mess up the positional embeddings in the decoder.
        padding_size = 0
        if step > 0:
            padding_size = step + 1
            padding = torch.full(
                (last_predictions.shape[0], padding_size),
                self._pad_id,
                dtype=last_predictions.dtype,
                device=last_predictions.device,
            )
            last_predictions = torch.cat([padding, last_predictions], dim=-1)

        # Every state entry other than the three fixed keys is a flattened
        # decoder-cache tensor from a previous step.
        decoder_cache = None
        decoder_cache_dict = {
            k: (state[k].contiguous() if state[k] is not None else None)
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        log_probabilities = None
        for i in range(padding_size, last_predictions.shape[1]):
            encoder_outputs = ((state["encoder_states"], )
                               if state["encoder_states"] is not None else None)
            outputs = self.bart(
                input_ids=state["input_ids"],
                attention_mask=state["input_mask"],
                encoder_outputs=encoder_outputs,
                decoder_input_ids=last_predictions[:, :i + 1],
                past_key_values=decoder_cache,
                generation_mode=True,
                use_cache=True,
            )

            decoder_log_probabilities = F.log_softmax(outputs[0][:, 0], dim=-1)

            if log_probabilities is None:
                log_probabilities = decoder_log_probabilities
            else:
                idx = last_predictions[:, i].view(-1, 1)
                log_probabilities = decoder_log_probabilities + log_probabilities.gather(
                    dim=-1, index=idx)

            decoder_cache = outputs[1]

            state["encoder_states"] = outputs[2]

        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        return log_probabilities, state

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """

        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor should have
            shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of lists of
            tokens.

        """
        predictions = output_dict["predictions"]
        # Convert every batch element's own ids (row i, not always row 0).
        predicted_tokens = [
            self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()}, self.vocab)
            for i in range(predictions.shape[0])
        ]
        output_dict["predicted_tokens"] = predicted_tokens

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ROUGE and BLEU when the model is not in training mode."""
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics
def __init__(self,
             vocab: Vocabulary,
             bidaf_model: BidirectionalAttentionFlowModified,
             source_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             attention: Attention,
             beam_size: int,
             max_decoding_steps: int,
             target_embedding_dim: int = 30,
             copy_token: str = "@COPY@",
             source_namespace: str = "source_tokens",
             target_namespace: str = "target_tokens",
             tensor_based_metric: Metric = None,
             token_based_metric: Metric = None,
             initializer: InitializerApplicator = InitializerApplicator(),
             dropout: float = 0.0,
             pretrained_bidaf: bool = False) -> None:
    """
    Build the encoder, attention, LSTM decoder cell, generation/copy scoring
    layers and beam search of a copy-augmented decoder that sits on top of a
    BiDAF model.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Vocabulary with the source and target namespaces.
    bidaf_model : ``BidirectionalAttentionFlowModified``
        BiDAF model to use when ``pretrained_bidaf`` is false.
    source_embedder : ``TextFieldEmbedder``
        Embedder for the source tokens.
    encoder : ``Seq2SeqEncoder``
        Source encoder; its output dim also fixes the decoder dims.
    attention : ``Attention``
        Attention module stored for the decoder.
    beam_size : ``int``
        Beam width passed to ``BeamSearch``.
    max_decoding_steps : ``int``
        Maximum number of beam search steps.
    target_embedding_dim : ``int``
        Size of the target token embeddings.
    copy_token : ``str``
        Token added to the target namespace; its index is kept as
        ``_copy_index``.
    source_namespace, target_namespace : ``str``
        Vocabulary namespaces.
    tensor_based_metric : ``Metric``, optional
        Defaults to BLEU excluding the pad, end and start indices.
    token_based_metric : ``Metric``, optional
        Stored as given.
    initializer : ``InitializerApplicator``
        Applied to the model once all submodules are created.
    dropout : ``float``
        Dropout probability; values ``<= 0`` install an identity function.
    pretrained_bidaf : ``bool``
        When true, ``bidaf_model`` is ignored and a model is loaded from the
        hard-coded ``./temp/bidaf_baseline`` directory.
    """
    super().__init__(vocab)
    if pretrained_bidaf:
        # Rebuild the BiDAF model from its own archive files; note that
        # ``vocab`` is rebound locally to the BiDAF vocabulary, while
        # ``self.vocab`` (set by super().__init__) stays unchanged.
        params = Params.from_file("./temp/bidaf_baseline/config.json")
        vocab = Vocabulary.from_files("./temp/bidaf_baseline/vocabulary")
        self.bidaf_model = Model.from_params(vocab=vocab,
                                             params=params.pop('model'))
        # Map GPU-saved weights to the CPU when CUDA is unavailable.
        map_location = None if torch.cuda.is_available() else 'cpu'
        # NOTE(review): torch.load unpickles the file, so the checkpoint
        # must come from a trusted source.
        with open("./temp/bidaf_baseline/best.th", 'rb') as f:
            self.bidaf_model.load_state_dict(
                torch.load(f, map_location=map_location))
    else:
        self.bidaf_model = bidaf_model
    self._source_namespace = source_namespace
    self._target_namespace = target_namespace
    # Special-token indices in the source and target namespaces.
    self._src_start_index = self.vocab.get_token_index(
        START_SYMBOL, self._source_namespace)
    self._src_end_index = self.vocab.get_token_index(
        END_SYMBOL, self._source_namespace)
    self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                   self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                 self._target_namespace)
    self._oov_index = self.vocab.get_token_index(self.vocab._oov_token,
                                                 self._target_namespace)  # pylint: disable=protected-access
    self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                 self._target_namespace)  # pylint: disable=protected-access
    # The copy token is added before the target vocabulary size is read
    # below, so that size accounts for it.
    self._copy_index = self.vocab.add_token_to_namespace(
        copy_token, self._target_namespace)

    self._tensor_based_metric = tensor_based_metric or \
        BLEU(exclude_indices={self._pad_index, self._end_index, self._start_index})
    self._token_based_metric = token_based_metric

    self._target_vocab_size = self.vocab.get_vocab_size(
        self._target_namespace)

    # Encoding modules.
    self._source_embedder = source_embedder
    self._encoder = encoder

    # Decoder output dim needs to be the same as the encoder output dim since we initialize the
    # hidden state of the decoder with the final hidden state of the encoder.
    # We arbitrarily set the decoder's input dimension to be the same as the output dimension.
    self.encoder_output_dim = self._encoder.get_output_dim()
    self.decoder_output_dim = self.encoder_output_dim
    self.decoder_input_dim = self.decoder_output_dim

    # Projects [encoder output ; BiDAF modeling-layer output] to the decoder
    # size. NOTE(review): presumably used to build the initial decoder
    # state, as the attribute name suggests — confirm at the call site.
    modeling_dim = self.bidaf_model._modeling_layer.get_output_dim()
    self._init_decoder_projection = Linear(
        self.encoder_output_dim + modeling_dim, self.decoder_output_dim)

    target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

    # The decoder input will be a function of the embedding of the previous predicted token,
    # an attended encoder hidden state called the "attentive read", and another
    # weighted sum of the encoder hidden state called the "selective read".
    # While the weights for the attentive read are calculated by an `Attention` module,
    # the weights for the selective read are simply the predicted probabilities
    # corresponding to each token in the source sentence that matches the target
    # token from the previous timestep.
    self._target_embedder = Embedding(target_vocab_size,
                                      target_embedding_dim)
    self._attention = attention
    self._input_projection_layer = Linear(
        target_embedding_dim + self.encoder_output_dim * 2,
        self.decoder_input_dim)

    # We then run the projected decoder input through an LSTM cell to produce
    # the next hidden state.
    self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                  self.decoder_output_dim)

    # We create a "generation" score for each token in the target vocab
    # with a linear projection of the decoder hidden state.
    self._output_generation_layer = Linear(self.decoder_output_dim,
                                           target_vocab_size)

    # We create a "copying" score for each source token by applying a non-linearity
    # (tanh) to a linear projection of the encoded hidden state for that token,
    # and then taking the dot product of the result with the decoder hidden state.
    self._output_copying_layer = Linear(self.encoder_output_dim,
                                        self.decoder_output_dim)

    # At prediction time, we'll use a beam search to find the best target sequence.
    self._beam_search = BeamSearch(self._end_index,
                                   max_steps=max_decoding_steps,
                                   beam_size=beam_size)
    # With no dropout requested, an identity function keeps call sites uniform.
    if dropout > 0:
        self._dropout = torch.nn.Dropout(p=dropout)
    else:
        self._dropout = lambda x: x
    # Must run last so every submodule created above is initialised.
    initializer(self)
def __init__(
    self,
    vocab: Vocabulary,
    decoder_net: DecoderNet,
    max_decoding_steps: int,
    target_embedder: Embedding,
    target_namespace: str = "tokens",
    tie_output_embedding: bool = False,
    scheduled_sampling_ratio: float = 0,
    label_smoothing_ratio: Optional[float] = None,
    beam_size: int = 4,
    tensor_based_metric: Metric = None,
    token_based_metric: Metric = None,
    bleu_exclude_tokens: Optional[List[str]] = None,
) -> None:
    """
    Build the decoder: beam search, output projection and validation metrics.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary used to look up the start / end / padding indices and the
        size of the target namespace.
    decoder_net : ``DecoderNet``, required
        Network that turns previous predictions and encoder outputs into
        decoder hidden states.
    max_decoding_steps : ``int``, required
        Maximum number of steps handed to the beam search.
    target_embedder : ``Embedding``, required
        Embedder for target tokens; its output dim must equal
        ``decoder_net.target_embedding_dim``.
    target_namespace : ``str``, optional (default = 'tokens')
        Vocabulary namespace of the target tokens.
    tie_output_embedding : ``bool``, optional (default = False)
        If set, the output projection shares its weight with ``target_embedder``.
    scheduled_sampling_ratio : ``float``, optional (default = 0)
        Stored on the instance for use by the training loop.
    label_smoothing_ratio : ``float``, optional (default = None)
        Stored on the instance and later passed to the loss as label smoothing.
    beam_size : ``int``, optional (default = 4)
        Width of the beam for beam search.
    tensor_based_metric : ``Metric``, optional (default = None)
        Metric fed with index tensors. A ``BLEU`` instance is rebuilt so that it
        additionally ignores padding and ``bleu_exclude_tokens``.
    token_based_metric : ``Metric``, optional (default = None)
        Metric fed with lists of string tokens.
    bleu_exclude_tokens : ``List[str]``, optional (default = None)
        Extra target tokens whose indices BLEU should ignore. ``None`` means
        "no extra tokens".

    Raises
    ------
    ConfigurationError
        If the embedder output dim does not match the decoder input dim, or if
        weight tying is requested but the two weight shapes differ.
    """
    super().__init__(target_embedder)

    self._vocab = vocab

    # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
    self._decoder_net = decoder_net
    self._max_decoding_steps = max_decoding_steps
    self._target_namespace = target_namespace
    self._label_smoothing_ratio = label_smoothing_ratio

    # At prediction time, we use a beam search to find the most likely sequence of target tokens.
    # We need the start symbol to provide as the input at the first timestep of decoding, and
    # end symbol as a way to indicate the end of the decoded sequence.
    self._start_index = self._vocab.get_token_index(
        START_SYMBOL, self._target_namespace
    )
    self._end_index = self._vocab.get_token_index(
        END_SYMBOL, self._target_namespace
    )
    self._beam_search = BeamSearch(
        self._end_index, max_steps=max_decoding_steps, beam_size=beam_size
    )

    target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

    if (
        self.target_embedder.get_output_dim()
        != self._decoder_net.target_embedding_dim
    ):
        raise ConfigurationError(
            "Target Embedder output_dim doesn't match decoder module's input."
        )

    # We project the hidden state from the decoder into the output vocabulary space
    # in order to get log probabilities of each target token, at each time step.
    self._output_projection_layer = Linear(
        self._decoder_net.get_output_dim(), target_vocab_size
    )

    if tie_output_embedding:
        if (
            self._output_projection_layer.weight.shape
            != self.target_embedder.weight.shape
        ):
            raise ConfigurationError(
                "Can't tie embeddings with output linear layer, due to shape mismatch"
            )
        self._output_projection_layer.weight = self.target_embedder.weight

    # These metrics will be updated during training and validation
    if isinstance(tensor_based_metric, BLEU):
        pad_index = self._vocab.get_token_index(
            self._vocab._padding_token, self._target_namespace  # pylint: disable=protected-access
        )
        # Merge padding, the user-supplied tokens and whatever the passed-in
        # metric already excluded, then rebuild the metric with that set.
        new_exclude_indices = {pad_index}
        new_exclude_indices.update(
            self._vocab.get_token_index(token, self._target_namespace)
            for token in (bleu_exclude_tokens or [])
        )
        new_exclude_indices.update(tensor_based_metric._exclude_indices)
        logger.info(
            f"Reconstruct BLEU to exclude indices {' '.join(map(str, new_exclude_indices))}"
        )
        self._tensor_based_metric = BLEU(
            tensor_based_metric._ngram_weights, new_exclude_indices
        )
    else:
        self._tensor_based_metric = tensor_based_metric

    self._token_based_metric = token_based_metric

    self._scheduled_sampling_ratio = scheduled_sampling_ratio
class BleuAutoRegressiveSeqDecoder(SeqDecoder):
    """
    An autoregressive decoder that can be used for most seq2seq tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    decoder_net : ``DecoderNet``, required
        Module that contains implementation of neural network for decoding output elements
    max_decoding_steps : ``int``, required
        Maximum length of decoded sequences.
    target_embedder : ``Embedding``, required
        Embedder for target tokens.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    tie_output_embedding : ``bool``, optional (default = False)
        If set, the output projection layer shares its weight with ``target_embedder``. The two
        weight tensors must have the same shape.
    scheduled_sampling_ratio : ``float`` optional (default = 0)
        Defines ratio between teacher forced training and real output usage. If it is zero
        (teacher forcing only) and `decoder_net` supports parallel decoding, we get the output
        predictions in a single forward pass of the `decoder_net`.
    label_smoothing_ratio : ``float``, optional (default = None)
        Passed as ``label_smoothing`` to the cross entropy loss.
    beam_size : ``int``, optional (default = 4)
        Width of the beam for beam search.
    tensor_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes raw tensors when it is called.
        This metric must accept two arguments when called: a batched tensor
        of predicted token indices, and a batched tensor of gold token indices.
        If it is a ``BLEU`` instance, it is rebuilt so that it additionally excludes the padding
        index and the indices of ``bleu_exclude_tokens``.
    token_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes lists of lists of tokens
        as input. This metric must accept two arguments when called, both
        of type `List[List[str]]`. The first is a predicted sequence for each item
        in the batch and the second is a gold sequence for each item in the batch.
    bleu_exclude_tokens : ``List[str]``, optional (default = None)
        Extra target tokens whose indices BLEU should ignore. ``None`` means no extra tokens.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        decoder_net: DecoderNet,
        max_decoding_steps: int,
        target_embedder: Embedding,
        target_namespace: str = "tokens",
        tie_output_embedding: bool = False,
        scheduled_sampling_ratio: float = 0,
        label_smoothing_ratio: Optional[float] = None,
        beam_size: int = 4,
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
        bleu_exclude_tokens: Optional[List[str]] = None,
    ) -> None:
        super().__init__(target_embedder)

        self._vocab = vocab

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._decoder_net = decoder_net
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(
            START_SYMBOL, self._target_namespace
        )
        self._end_index = self._vocab.get_token_index(
            END_SYMBOL, self._target_namespace
        )
        self._beam_search = BeamSearch(
            self._end_index, max_steps=max_decoding_steps, beam_size=beam_size
        )

        target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

        if (
            self.target_embedder.get_output_dim()
            != self._decoder_net.target_embedding_dim
        ):
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(
            self._decoder_net.get_output_dim(), target_vocab_size
        )

        if tie_output_embedding:
            if (
                self._output_projection_layer.weight.shape
                != self.target_embedder.weight.shape
            ):
                raise ConfigurationError(
                    "Can't tie embeddings with output linear layer, due to shape mismatch"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        # These metrics will be updated during training and validation
        if isinstance(tensor_based_metric, BLEU):
            pad_index = self._vocab.get_token_index(
                self._vocab._padding_token, self._target_namespace  # pylint: disable=protected-access
            )
            # Merge padding, the user-supplied tokens and whatever the passed-in
            # metric already excluded, then rebuild the metric with that set.
            new_exclude_indices = {pad_index}
            new_exclude_indices.update(
                self._vocab.get_token_index(token, self._target_namespace)
                for token in (bleu_exclude_tokens or [])
            )
            new_exclude_indices.update(tensor_based_metric._exclude_indices)
            logger.info(
                f"Reconstruct BLEU to exclude indices {' '.join(map(str, new_exclude_indices))}"
            )
            self._tensor_based_metric = BLEU(
                tensor_based_metric._ngram_weights, new_exclude_indices
            )
        else:
            self._tensor_based_metric = tensor_based_metric

        self._token_based_metric = token_based_metric

        self._scheduled_sampling_ratio = scheduled_sampling_ratio

    def _forward_beam_search(
        self, state: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare inputs for the beam search, does beam search and returns beam search results.

        The returned dict holds ``"predictions"`` (the top-k sequences) and
        ``"class_log_probabilities"`` (their scores).
        """
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size,), fill_value=self._start_index
        )

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step
        )

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _forward_loss(
        self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Run the decoder over the gold targets and return ``{"loss": loss}``.

        When scheduled sampling is off and the decoder net can decode in parallel, all
        timesteps are computed in a single call; otherwise the decoder is unrolled step
        by step, optionally feeding back its own predictions.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = target_tokens["tokens"]

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # Computed once and reused for both the parallel decoder call and the loss.
        # shape: (batch_size, max_target_batch_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel:
            _, decoder_output = self._decoder_net(
                previous_state=state,
                previous_steps_predictions=target_embedding[:, :-1, :],
                encoder_outputs=encoder_outputs,
                source_mask=source_mask,
                previous_steps_mask=target_mask[:, :-1],
            )

            # shape: (group_size, max_target_sequence_length, num_classes)
            logits = self._output_projection_layer(decoder_output)
        else:
            batch_size = source_mask.size()[0]
            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1

            # Initialize target predictions with the start index.
            # shape: (batch_size,)
            last_predictions = source_mask.new_full(
                (batch_size,), fill_value=self._start_index
            )

            # History of the embeddings of the model's own predictions; empty until
            # the first step has produced one.
            # shape: (steps, batch_size, target_embedding_dim)
            steps_embeddings = torch.Tensor([])

            step_logits: List[torch.Tensor] = []

            for timestep in range(num_decoding_steps):
                if (
                    self.training
                    and torch.rand(1).item() < self._scheduled_sampling_ratio
                ):
                    # Scheduled sampling: with probability `_scheduled_sampling_ratio`
                    # during training, feed back the model's own previous predictions.
                    # shape: (batch_size, steps, target_embedding_dim)
                    state["previous_steps_predictions"] = steps_embeddings

                    # shape: (batch_size, )
                    effective_last_prediction = last_predictions
                else:
                    # Use gold tokens at test time and at a rate of
                    # 1 - _scheduled_sampling_ratio during training.
                    # shape: (batch_size, )
                    effective_last_prediction = targets[:, timestep]

                    if timestep == 0:
                        state["previous_steps_predictions"] = torch.Tensor([])
                    else:
                        # shape: (batch_size, steps, target_embedding_dim)
                        state["previous_steps_predictions"] = target_embedding[
                            :, :timestep
                        ]

                # shape: (batch_size, num_classes)
                output_projections, state = self._prepare_output_projections(
                    effective_last_prediction, state
                )

                # list of tensors, shape: (batch_size, 1, num_classes)
                step_logits.append(output_projections.unsqueeze(1))

                # shape (predicted_classes): (batch_size,)
                _, predicted_classes = torch.max(output_projections, 1)

                # shape (predicted_classes): (batch_size,)
                last_predictions = predicted_classes

                # shape: (batch_size, 1, target_embedding_dim)
                last_predictions_embeddings = self.target_embedder(
                    last_predictions
                ).unsqueeze(1)

                # This step is required, since we want to keep up two different prediction
                # histories: gold and real
                if steps_embeddings.shape[-1] == 0:
                    # There is no previous steps, except for start vectors in ``last_predictions``
                    # shape: (group_size, 1, target_embedding_dim)
                    steps_embeddings = last_predictions_embeddings
                else:
                    # shape: (group_size, steps_count, target_embedding_dim)
                    steps_embeddings = torch.cat(
                        [steps_embeddings, last_predictions_embeddings], 1
                    )

            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

        # Compute loss.
        loss = self._get_loss(logits, targets, target_mask)

        # TODO: We will be using beam search to get predictions for validation, but if beam size in 1
        # we could consider taking the last_predictions here and building step_predictions
        # and use that instead of running beam search again, if performance in validation is taking a hit
        output_dict = {"loss": loss}

        return output_dict

    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`. ``state`` is updated in place
        with the extended prediction history and the new decoder state.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions")

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(
            1
        )

        if (
            previous_steps_predictions is None
            or previous_steps_predictions.shape[-1] == 0
        ):
            # There is no previous steps, except for start vectors in ``last_predictions``
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1
            )

        decoder_state, decoder_output = self._decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions,
        )
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)

        if self._decoder_net.decodes_parallel:
            # A parallel decoder returns every step; only the newest one is needed.
            decoder_output = decoder_output[:, -1, :]

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _get_loss(
        self,
        logits: torch.LongTensor,
        targets: torch.LongTensor,
        target_mask: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(
            logits,
            relevant_targets,
            relevant_mask,
            label_smoothing=self._label_smoothing_ratio,
        )

    def get_output_dim(self):
        """Return the output dimension reported by the wrapped decoder net."""
        return self._decoder_net.get_output_dim()

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state
        )

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Collect the configured metrics; returns an empty dict while training."""
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset)  # type: ignore
                )
            if self._token_based_metric is not None:
                all_metrics.update(self._token_based_metric.get_metric(reset=reset))  # type: ignore
        return all_metrics

    @overrides
    def forward(
        self,
        encoder_out: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the loss when targets are given and, outside training, run beam search.

        Outside training, if targets are given, the configured metrics are updated
        with the best beam.
        """
        state = encoder_out

        if target_tokens:
            decoder_init_state = self._decoder_net.init_decoder_state(state)
            decoder_init_state.update(state)
            output_dict = self._forward_loss(decoder_init_state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            # A fresh decoder state is built here because the loss pass above
            # mutates the state dict it was given.
            decoder_init_state = self._decoder_net.init_decoder_state(state)
            decoder_init_state.update(state)
            predictions = self._forward_beam_search(decoder_init_state)
            output_dict.update(predictions)

            if target_tokens:
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]

                    self._tensor_based_metric(  # type: ignore
                        best_predictions, target_tokens["tokens"]
                    )

                if self._token_based_metric is not None:
                    output_dict = self.post_process(output_dict)
                    predicted_tokens = output_dict["predicted_tokens"]

                    # The gold sequence is compared without its first position.
                    self._token_based_metric(  # type: ignore
                        predicted_tokens,
                        self.indices_to_tokens(target_tokens["tokens"][:, 1:]),
                    )

        return output_dict

    @overrides
    def post_process(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        all_predicted_tokens = self.indices_to_tokens(predicted_indices)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def indices_to_tokens(self, batch_indeces: numpy.ndarray) -> List[List[str]]:
        """
        Convert a batch of index sequences into lists of target-namespace tokens.

        Anything that is not already a ``numpy.ndarray`` is detached, moved to CPU and
        converted first. Each sequence is cut just before its first end symbol.
        """
        if not isinstance(batch_indeces, numpy.ndarray):
            batch_indeces = batch_indeces.detach().cpu().numpy()

        all_tokens = []
        for indices in batch_indeces:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[: indices.index(self._end_index)]
            tokens = [
                self._vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
            ]
            all_tokens.append(tokens)
        return all_tokens
def __init__(
    self,
    vocab: Vocabulary,
    source_embedder: TextFieldEmbedder,
    target_embedder: TextFieldEmbedder,
    source_encoder: Seq2SeqEncoder,
    target_encoder: Seq2SeqEncoder,
    max_decoding_steps: int,
    attention: Attention = None,
    s2s_attention: Attention = None,
    t2t_attention: Attention = None,
    beam_size: int = None,
    target_namespace: str = "tokens",
    scheduled_sampling_ratio: float = 0.,
    use_bleu: bool = True,
) -> None:
    """
    Wire up embedders, encoders, the three attention modules, the LSTM decoder cell,
    the output projection, beam search and (optionally) a BLEU metric.

    Raises ``NotImplementedError`` when ``attention`` is falsy.
    """
    super(AssociativeSeq2SeqChainedAttention, self).__init__(vocab)

    self._target_namespace = target_namespace
    self._scheduled_sampling_ratio = scheduled_sampling_ratio

    # Start symbol: fed to the decoder at the first timestep.
    # End symbol: marks where a decoded sequence stops.
    namespace = self._target_namespace
    self._start_index = self.vocab.get_token_index(START_SYMBOL, namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, namespace)

    if not use_bleu:
        self._bleu = None
    else:
        padding_index = self.vocab.get_token_index(
            self.vocab._padding_token, namespace)  # pylint: disable=protected-access
        ignored_indices = {padding_index, self._end_index, self._start_index}
        self._bleu = BLEU(exclude_indices=ignored_indices)

    # Beam search picks the most likely target sequence at prediction time;
    # a missing / zero beam size falls back to 1.
    self._max_decoding_steps = max_decoding_steps
    self._beam_search = BeamSearch(
        self._end_index,
        max_steps=max_decoding_steps,
        beam_size=beam_size or 1,
    )

    # Source-side embedding and the two sequence encoders.
    self._source_embedder = source_embedder
    self._source_encoder = source_encoder
    self._target_encoder = target_encoder

    # The decoder hidden size mirrors the target encoder's output size.
    self._encoder_output_dim = self._target_encoder.get_output_dim()
    self._decoder_output_dim = self._encoder_output_dim

    target_embedding_dim = target_embedder.get_output_dim()

    # Only the attention-based variant is implemented.
    if not attention:
        raise NotImplementedError

    self._attention = attention
    self._s2s_attention = s2s_attention
    self._t2t_attention = t2t_attention
    # Decoder input = attended vector concatenated with the previous token embedding.
    self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim

    self._target_embedder = target_embedder
    self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                  self._decoder_output_dim)
    self._output_projection_layer = Linear(
        self._decoder_output_dim,
        self.vocab.get_vocab_size(self._target_namespace),
    )
class AssociativeSeq2SeqChainedAttention(Model): def __init__( self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, target_embedder: TextFieldEmbedder, source_encoder: Seq2SeqEncoder, target_encoder: Seq2SeqEncoder, #instance_reference_similarity_function:MatrixAttention, #target_instance_ref_similarity_function:MatrixAttention, max_decoding_steps: int, attention: Attention = None, s2s_attention: Attention = None, t2t_attention: Attention = None, beam_size: int = None, target_namespace: str = "tokens", scheduled_sampling_ratio: float = 0., use_bleu: bool = True) -> None: super(AssociativeSeq2SeqChainedAttention, self).__init__(vocab) self._target_namespace = target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={ pad_index, self._end_index, self._start_index }) else: self._bleu = None # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._source_embedder = source_embedder # Encodes the sequence of source embeddings into a sequence of hidden states. 
self._source_encoder = source_encoder self._target_encoder = target_encoder self._encoder_output_dim = self._target_encoder.get_output_dim() self._decoder_output_dim = self._encoder_output_dim target_embedding_dim = target_embedder.get_output_dim() if attention: self._attention = attention self._s2s_attention = s2s_attention self._t2t_attention = t2t_attention self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim else: raise NotImplementedError num_classes = self.vocab.get_vocab_size(self._target_namespace) self._target_embedder = target_embedder self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. 
Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward( self, # type: ignore ref_source_tokens: Dict[str, torch.LongTensor], instance_source_tokens: Dict[str, torch.LongTensor], ref_target_tokens: Dict[str, torch.LongTensor], instance_target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: state = self._encode(ref_target_tokens, ref_source_tokens, instance_source_tokens) if instance_target_tokens: state = self._init_decoder_state(state) # The `_forward_loop` decodes the input sequence and computes the loss during training # and validation. output_dict = self._forward_loop(state, instance_target_tokens) else: output_dict = {} if not self.training: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if instance_target_tokens and self._bleu: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, instance_target_tokens["tokens"]) return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. 
This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode( self, ref_target_tokens: Dict[str, torch.Tensor], ref_source_tokens: Dict[str, torch.Tensor], instance_source_tokens: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_ref_target = self._target_embedder(ref_target_tokens) ref_target_mask = util.get_text_field_mask(ref_target_tokens) encoded_ref_target = self._target_encoder(embedded_ref_target, ref_target_mask) embedded_ref_source = self._source_embedder(ref_source_tokens) ref_source_mask = util.get_text_field_mask(ref_source_tokens) encoded_ref_source = self._source_encoder(embedded_ref_source, ref_source_mask) embedded_instance_source = self._source_embedder( instance_source_tokens) instance_source_mask = util.get_text_field_mask(instance_source_tokens) encoded_instance_source = self._source_encoder( embedded_instance_source, instance_source_mask) return { "source_mask": ref_target_mask, "encoder_outputs": encoded_ref_target, "ref_source_mask": ref_source_mask, "ref_source_encoded": 
embedded_ref_source, "instance_source_mask": instance_source_mask, "instance_source_encoded": embedded_instance_source } def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._target_encoder.is_bidirectional()) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros( batch_size, self._decoder_output_dim) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. 
# shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. # shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( input_choices, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(class_probabilities, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = {"predictions": predictions} if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. 
target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) output_dict["loss"] = loss return output_dict def _forward_beam_search( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Make forward pass during prediction using a beam search.""" batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size, ), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _prepare_output_projections(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long """ Decode current state and last prediction to produce produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`. 
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder._token_embedders['tokens']( last_predictions) if self._attention: # shape: (group_size, encoder_output_dim) attended_input_source = self._prepare_attended_input( decoder_hidden, state['encoder_outputs'], state['source_mask'], self._t2t_attention) attended_input_ref_source = self._prepare_attended_input( attended_input_source, state['ref_source_encoded'], state['ref_source_mask'], self._attention) attended_input_instance_source = self._prepare_attended_input( attended_input_ref_source, state['instance_source_encoded'], state['instance_source_mask'], self._s2s_attention) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat( (attended_input_instance_source, embedded_input), -1) else: raise NotImplementedError # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_hidden) return output_projections, state def _prepare_attended_input(self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None, attention=None) -> torch.Tensor: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. 
Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs_mask = encoder_outputs_mask.float() # shape: (batch_size, max_input_sequence_length) input_weights = attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. 
The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) return all_metrics