class BertSeq2Seq(Model):
    """
    This ``BertSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    vocab_file : ``str``, required
        Path to the BERT vocabulary file (one token per line) used to build the target-side
        vocabulary.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences. The BERT embedder doubles as the encoder of the
        "encoder/decoder" model.
    max_decoding_steps : ``int``, optional (default = 500)
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify
        the target's namespace here. If not, we'll assume it is "tokens", which is also the
        default choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the
        same value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function : ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be
        deprecated since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = 5)
        Width of the beam for beam search. A beam size of 1 amounts to greedy decoding.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = False)
        If True, the BLEU metric will be calculated during validation.
""" def __init__( self, vocab: Vocabulary, vocab_file: str, source_embedder: TextFieldEmbedder, max_decoding_steps: int = 500, attention: Attention = None, attention_function: SimilarityFunction = None, beam_size: int = 5, target_namespace: str = "tokens", target_embedding_dim: int = None, scheduled_sampling_ratio: float = 0., use_bleu: bool = False, initializer: InitializerApplicator = InitializerApplicator(), ) -> None: super(BertSeq2Seq, self).__init__(vocab) self._target_namespace = target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio self.bert_vocab_token_to_index, self.bert_vocab_index_to_token = self.build_vocab( vocab_file) # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.bert_vocab_token_to_index['[CLS]'] self._end_index = self.bert_vocab_token_to_index['[SEP]'] if use_bleu: pad_index = 0 # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={ pad_index, self._end_index, self._start_index }) else: self._bleu = None # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._source_embedder = source_embedder num_classes = len(self.bert_vocab_token_to_index) # Attention mechanism applied to the encoder output for each step. if attention: if attention_function: raise ConfigurationError( "You can only specify an attention module or an " "attention function, but not both.") self._attention = attention elif attention_function: self._attention = LegacyAttention(attention_function) else: self._attention = None # Dense embedding of vocab words in the target space. target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim( ) self._target_embedder = Embedding(num_classes, target_embedding_dim) # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. self._encoder_output_dim = 1024 self._decoder_output_dim = self._encoder_output_dim if self._attention: # If using attention, a weighted average over encoder outputs will be concatenated # to the previous target embedding to form the input to the decoder at each # time step. self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim else: # Otherwise, the input to the decoder is just the previous target embedding. self._decoder_input_dim = target_embedding_dim # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. # TODO (pradeep): Do not hardcode decoder cell type. self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) self._input_projection_layer = Linear(self._encoder_output_dim, self._encoder_output_dim) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. 
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

        initializer(self)

    def build_vocab(self, vocab_file):
        """Read a BERT vocabulary file (one token per line) into token <-> index maps."""
        vocab_token_to_index, vocab_index_to_token = {}, {}
        with open(vocab_file) as f:
            vocabulary = f.readlines()
        for index, voc in enumerate(vocabulary):
            vocab_token_to_index[voc.strip()] = index
            vocab_index_to_token[index] = voc.strip()
        return vocab_token_to_index, vocab_index_to_token

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step.  This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["bert"])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            predicted_tokens = []
            for beam_indices in indices:
                beam_indices = list(beam_indices)
                # Collect indices till the first end_symbol
                if self._end_index in beam_indices:
                    beam_indices = beam_indices[:beam_indices.index(self._end_index)]
                predicted_token = [self.bert_vocab_index_to_token[x] for x in beam_indices]
                predicted_tokens.append(predicted_token)
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Use BERT to encode the input; the contextual embeddings double as the encoder outputs.
        del source_tokens['mask']
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        embedded_input = self._input_projection_layer(embedded_input)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        return {
            "source_mask": source_mask,
            "encoder_outputs": embedded_input,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # Use the first ([CLS]) position as the final encoder output.
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = state["encoder_outputs"][:, 0, :]
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size,
                                                                      self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
            We really only use the predictions from the method to test that beam search
            with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens['bert']

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Use the model's own predictions with probability _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # Otherwise, use the gold tokens (teacher forcing).
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            del target_tokens['mask']
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,),
                                                          fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) if self._attention: # shape: (group_size, encoder_output_dim) attended_input = self._prepare_attended_input( decoder_hidden, encoder_outputs, source_mask) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((attended_input, embedded_input), -1) else: # shape: (group_size, target_embedding_dim) decoder_input = embedded_input # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_hidden) return output_projections, state def _prepare_attended_input( self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs_mask = encoder_outputs_mask.float() # shape: (batch_size, max_input_sequence_length) input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. 
           The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P>
           and the mask would be                     1   1  1  1  1   0  0
           and let the logits be                     l1  l2 l3 l4 l5 l6

        We actually need to compare:
           the sequence           w1 w2 w3 <E> <P> <P>
           with masks             1  1  1  1   0   0
           against                l1 l2 l3 l4 l5 l6
           (where the input was)  <S> w1 w2 w3 <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
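
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): a toy check of the
# one-step shift that `_get_loss` performs before the masked cross entropy.
# The token ids and sizes below are made up for the example; only
# `util.sequence_cross_entropy_with_logits` comes from AllenNLP.
def _demo_loss_alignment() -> None:
    import torch
    from allennlp.nn import util as nn_util

    # Pretend target sequence <S> w1 w2 w3 <E>, padded to 7 tokens.
    targets = torch.tensor([[1, 2, 3, 4, 5, 0, 0]])
    target_mask = (targets != 0).long()
    # The decoder emits one logit per input step, i.e. len(targets) - 1 = 6
    # steps; assume a toy vocabulary of 6 classes.
    logits = torch.randn(1, 6, 6)

    relevant_targets = targets[:, 1:].contiguous()   # w1 w2 w3 <E> <P> <P>
    relevant_mask = target_mask[:, 1:].contiguous()  # 1  1  1  1   0   0
    loss = nn_util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
    assert loss.dim() == 0  # a scalar loss, averaged over non-padded steps
# ---------------------------------------------------------------------------
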
class NmtSeq2Seq(Model):
    """
    This ``NmtSeq2Seq`` class is an adaptation of the ``SimpleSeq2Seq`` :class:`Model` from the
    AllenNLP toolkit, which takes a sequence, encodes it, and then uses the encoded
    representations to decode another sequence. We have removed some functionality.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    target_namespace : ``str``, required
        If the target side vocabulary is different from the source side's, you need to specify
        the target's namespace here. If not, we'll assume it is "tokens", which is also the
        default choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the
        same value as the source embedder's.
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    decoder : ``Dict``, required
        The parameters for the decoder RNN cell of the "encoder/decoder" model
    max_decoding_steps : ``int``, required
        Maximum length of decoded sequences.
    attention : ``Dict``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the Dict that holds parameters for the appropriate attention function
        to compute similarity between the decoder hidden state and encoder outputs.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    visualize_attention : ``bool``, optional (default = True)
        If True, attention weights are plotted for each decoded sequence at validation time.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 target_namespace: str,
                 encoder: Seq2SeqEncoder,
                 decoder: Dict,
                 max_decoding_steps: int,
                 target_embedding_dim: int = None,
                 attention: Dict = None,
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 visualize_attention: bool = True) -> None:
        super(NmtSeq2Seq, self).__init__(vocab)
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism params applied to the encoder output for each step.
        self._attention = attention
        self._visualize_attention = visualize_attention

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()

        self._decoder_input_dim = decoder["input_size"]
        # If using attention, make sure the .jsonnet params reflect this architecture:
        # input_to_decoder_rnn = [prev_word + attended_context_vector]
        self._decoder_output_dim = decoder['hidden_size']

        # We'll use an RNN cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        decoder_cell_type = decoder["type"]
        if decoder_cell_type == "gru":
            self._decoder_cell = GRUCell(self._decoder_input_dim, self._decoder_output_dim)
        elif decoder_cell_type == "lstm":
            self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
        else:
            raise ValueError(
                "Decoder cell of type {} not supported yet!".format(decoder_cell_type))

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step.  This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.
        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state, attention_probs = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            if self._visualize_attention:
                output_dict = self._forward_loop(state, target_tokens)
                attention_matrix = np.array([
                    np.array(x.view(-1))
                    for x in output_dict['attention_probs_per_timestep']
                ])
                in_seq = [
                    self.vocab._index_to_token['source_tokens'][x]
                    for x in source_tokens['source_tokens'].view(-1).data.numpy()
                ]
                out_seq = [
                    self.vocab._index_to_token['target_tokens'][x]
                    for x in target_tokens['target_tokens'].view(-1).data.numpy()
                ]
                plot_attention(attention_matrix, in_seq, out_seq,
                               'attention_plots/' + '_'.join(in_seq[1:4]))
            else:
                predictions = self._forward_beam_search(state)
                output_dict.update(predictions)
                if target_tokens and self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens[self._target_namespace])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
""" predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode( self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder(embedded_input, source_mask) return { "source_mask": source_mask, "encoder_outputs": encoder_outputs, } def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros( batch_size, self._decoder_output_dim) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens[self._target_namespace] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] attention_probs_per_timestep = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. 
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state, attention_probs = self._prepare_output_projections(
                input_choices, state)
            attention_probs_per_timestep.append(attention_probs)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {
            "predictions": predictions,
            "attention_probs_per_timestep": attention_probs_per_timestep
        }

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,),
                                                          fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]
                                    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections into the target space,
        which can then be used to get probabilities of each target token for the next step.

        Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) context_vector, attention_probs = self._compute_attention( decoder_hidden, encoder_outputs, source_mask) decoder_input = torch.cat((embedded_input, context_vector), dim=-1) # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_hidden) return output_projections, state, attention_probs def _compute_attention( self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None ) -> (torch.Tensor, torch.Tensor): """Apply attention over encoder outputs and decoder state. Parameters ---------- decoder_hidden_state : ``torch.LongTensor`` A tensor of shape ``(batch_size, decoder_output_dim)``, which contains the current decoder hidden state to be used as the 'query' to the attention computation during the last time step. encoder_outputs : ``torch.LongTensor`` A tensor of shape ``(batch_size, max_input_sequence_length, encoder_output_dim)``, which contains all the encoder hidden states of the source tokens, i.e., the 'keys' to the attention computation encoder_mask : ``torch.LongTensor`` A tensor of shape (batch_size, max_input_sequence_length), which contains the mask of the encoded input. We want to avoid computing an attention score for positions of the source with zero-values (remember not all input sentences have the same length) Returns ------- (torch.Tensor, torch.Tensor) A tensor of shape (batch_size, encoder_output_dim) that contains the attended encoder outputs (aka context vector), i.e., we have ``applied`` the attention scores on the encoder hidden states. Notes ----- Don't forget to apply the final softmax over the **masked** encoder outputs! """ # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length) encoder_outputs_mask = encoder_outputs_mask.float() attention_weights = encoder_outputs.bmm( decoder_hidden_state.unsqueeze(-1)).squeeze(-1) # Main body of attention weights computation here # decoder hidden state 1, 400 # encoder outputs 1, 14, 400 # encoder_outputs_mask = 1, 14 attention_probs = masked_softmax(attention_weights, encoder_outputs_mask) # attention weights = 1, 14 context_vector = util.weighted_sum(encoder_outputs, attention_probs) return context_vector, attention_probs @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss. 
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P>
           and the mask would be                     1   1  1  1  1   0  0
           and let the logits be                     l1  l2 l3 l4 l5 l6

        We actually need to compare:
           the sequence           w1 w2 w3 <E> <P> <P>
           with masks             1  1  1  1   0   0
           against                l1 l2 l3 l4 l5 l6
           (where the input was)  <S> w1 w2 w3 <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
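
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): the masked dot-product
# attention step performed in `_compute_attention`, using the toy shapes
# mentioned in its comments (decoder hidden (1, 400), encoder outputs
# (1, 14, 400)). Only `masked_softmax` and `weighted_sum` come from AllenNLP;
# everything else is made up for the example.
def _demo_dot_product_attention() -> None:
    import torch
    from allennlp.nn.util import masked_softmax, weighted_sum

    decoder_hidden = torch.randn(1, 400)
    encoder_outputs = torch.randn(1, 14, 400)
    mask = torch.ones(1, 14)

    # One raw score per source position via a batched dot product.
    scores = encoder_outputs.bmm(decoder_hidden.unsqueeze(-1)).squeeze(-1)  # (1, 14)
    # Softmax over the *masked* scores, so padded positions get zero weight.
    probs = masked_softmax(scores, mask)                                    # (1, 14)
    # Context vector: attention-weighted average of the encoder outputs.
    context = weighted_sum(encoder_outputs, probs)                          # (1, 400)
    assert context.shape == (1, 400)
# ---------------------------------------------------------------------------
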
class MultiTurnHred(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 token_embedder: TextFieldEmbedder,
                 document_encoder: Seq2VecEncoder,
                 utterance_encoder: Seq2VecEncoder,
                 context_encoder: Seq2SeqEncoder,
                 beam_size: int = None,
                 max_decoding_steps: int = 50,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True) -> None:
        super(MultiTurnHred, self).__init__(vocab)
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        self._beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=self._beam_size)

        # Dense embedding of word level tokens.
        self._token_embedder = token_embedder

        # Document word level encoder.
        self._document_encoder = document_encoder

        # Dialogue word level encoder.
        self._utterance_encoder = utterance_encoder

        # Sentence level encoder.
        self._context_encoder = context_encoder

        num_classes = self.vocab.get_vocab_size()

        document_output_dim = self._document_encoder.get_output_dim()
        utterance_output_dim = self._utterance_encoder.get_output_dim()
        context_output_dim = self._context_encoder.get_output_dim()
        decoder_output_dim = utterance_output_dim
        decoder_input_dim = (token_embedder.get_output_dim()
                             + document_output_dim + context_output_dim)

        # We'll use a GRU cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = GRUCell(decoder_input_dim, decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(decoder_output_dim, num_classes)

    @overrides
    def forward(self,  # type: ignore
                document: Dict[str, torch.LongTensor],
                dialogue: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.
""" state = {} # shape: (batch_size, document_length, embedding_dim) embedded_document = self._token_embedder(document) # shape: (batch_size, document_length) document_mask = util.get_text_field_mask(document) # shape: (batch_size, document_output_dim) import pdb; #pdb.set_trace() document_vec = self._document_encoder(embedded_document, document_mask) state['document_vec'] = document_vec # training and validation if dialogue is not None: # shape: (batch_size, sequence_num, sequence_length, embedding_dim) embedded_dialogue = self._token_embedder(dialogue) # shape: (batch_size, sequence_num, sequence_length) dialogue_mask = util.get_text_field_mask(dialogue, 1) state['dialogue_mask'] = dialogue_mask batch_size, sequence_num, sequence_length = dialogue_mask.size() # shape: (batch_size * sequence_num, utterance_output_dim) utterance_vec = self._utterance_encoder(embedded_dialogue. view(batch_size * sequence_num, sequence_length, -1), dialogue_mask.view(batch_size * sequence_num, -1)) # shape: (batch_size, sequence_num, utterance_output_dim) utterance_vec = utterance_vec.view(batch_size, sequence_num, -1) # shape: (batch_size, sequence_num, context_output_dim) context_vec = self._context_encoder(utterance_vec, dialogue_mask[:, :, 0]) # shape: (batch_size, sequence_num, utterance_output_dim) import pdb; #pdb.set_trace() decoder_hidden = torch.cat([utterance_vec.new_full((batch_size, 1, utterance_vec.size(-1)), fill_value=0.0), utterance_vec[:, :-1, :]], dim=1).contiguous() decoder_hidden = decoder_hidden.view(batch_size * sequence_num, -1) # shape: (batch_size * sequence_num, utterance_output_dim) state['decoder_hidden'] = decoder_hidden # shape: (batch_size, sequence_num, encoder_output_dim) # We reshape here to make it convenient to send to a torch GRUCell. context_vec = torch.cat([context_vec.new_full((batch_size, 1, context_vec.size(-1)), fill_value=0.0), context_vec[:, :-1, :]], dim=1).contiguous() # shape: (batch_size * sequence_num, encoder_output_dim) context_vec = context_vec.view(batch_size * sequence_num, -1) state['context_vec'] = context_vec output_dict = self._forward_loop(state, dialogue) if not self.training: state['decoder_hidden'] = decoder_hidden predictions = self._forward_beam_search(state) output_dict.update(predictions) if self._bleu: # shape: (batch_size * sequence_num, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"].view(batch_size * sequence_num, self._beam_size, -1) # shape: (batch_size * sequence_num, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, dialogue["tokens"].view(batch_size * sequence_num, -1)) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. 
""" # shape: (batch_size, sequence_num, beam_size, num_decoding_steps) predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for instance_indices in predicted_indices: instance_predicted_tokens = [] for indices in instance_indices: if len(indices.shape) > 1: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [self.vocab.get_token_from_index(x) for x in indices] instance_predicted_tokens.append(predicted_tokens) all_predicted_tokens.append(instance_predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict """ def _encode_document(self, document: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, sequence_num, sequence_length, embedding_dim) embedded_document = self._token_embedder(document) # shape: (batch_size, sequence_num, sequence_length) document_mask = util.get_text_field_mask(document, 1) batch_size, sequence_num, sequence_length = document_mask.size() # shape: (batch_size * sequence_num, encoder_output_dim) sequence_vec = self._sequence_encoder(embedded_document.view(batch_size*sequence_num, sequence_length, -1), document_mask.view(batch_size*sequence_num, -1)) # DEBUG: assert sequence_vec.size() == (batch_size * sequence_num, self._encoder_output_dim) # shape: (batch_size, sequence_num, encoder_output_dim) sequence_vec = sequence_vec.view(batch_size, sequence_num, -1) # DEBUG: assert sequence_vec.size() == (batch_size, sequence_num, self._encoder_output_dim) # shape: (batch_size, sequence_num, encoder_output_dim) context_encoder_output = self._context_encoder(sequence_vec, document_mask[:, :, 0]) # DEBUG: assert context_encoder_output.size() == (batch_size, sequence_num, self._encoder_output_dim) # shape: (batch_size, encoder_output_dim) document_vec = util.get_final_encoder_states(context_encoder_output, document_mask[:, :, 0]) # DEBUG: assert document_vec.size() == (batch_size, self._encoder_output_dim) return {'document_vec': document_vec} """ def _forward_loop(self, state: Dict[str, torch.Tensor], dialogue: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]: # shape: (batch_size, sequence_num, sequence_length) targets = dialogue["tokens"] batch_size, sequence_num, sequence_length = targets.size() num_decoding_steps = sequence_length - 1 # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. 
        last_predictions = targets.new_full((batch_size * sequence_num,),
                                            fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            # shape: (batch_size * sequence_num, num_classes)
            output_projections, state = self._prepare_output_projections(last_predictions, state)

            # list of tensors, shape: (batch_size * sequence_num, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size * sequence_num, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size * sequence_num,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size * sequence_num,)
            last_predictions = predicted_classes

            # list of tensors, shape: (batch_size * sequence_num, 1)
            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size * sequence_num, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        # shape: (batch_size * sequence_num, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)

        # Compute loss.
        # shape: (batch_size * sequence_num, sequence_length)
        target_mask = state["dialogue_mask"].view(batch_size * sequence_num, -1)
        targets = targets.view(batch_size * sequence_num, -1)
        loss = self._get_loss(logits, targets, target_mask)
        output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size, sequence_num, sequence_length = state["dialogue_mask"].size()
        start_predictions = state["dialogue_mask"].new_full((batch_size * sequence_num,),
                                                            fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size * sequence_num, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size * sequence_num, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)
        all_top_k_predictions = all_top_k_predictions.view(
            batch_size, sequence_num, self._beam_size, -1)
        log_probabilities = log_probabilities.view(batch_size, sequence_num, -1)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step.  This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(batch_size * sequence_num,)``, which gives the indices of the
            predictions during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(batch_size * sequence_num, *)``, where ``*`` can be any other
            number of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.
        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (batch_size * sequence_num, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (batch_size * sequence_num, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        batch_size = state['document_vec'].size(0)
        sequence_num = state['context_vec'].size(0) // batch_size

        # The decoder input concatenates the embedded previous token, the document vector
        # (broadcast over turns), and the (shifted) dialogue context vector.
        # shape: (batch_size * sequence_num, embedding_dim + document_output_dim + context_output_dim)
        embedded_input = torch.cat(
            [self._token_embedder({'tokens': last_predictions}),
             state['document_vec'].unsqueeze(1).expand(-1, sequence_num, -1).view(
                 batch_size * sequence_num, -1),
             state['context_vec']],
            dim=-1)

        # shape: (batch_size * sequence_num, decoder_output_dim)
        decoder_hidden = state['decoder_hidden']

        # shape: (batch_size * sequence_num, decoder_output_dim)
        decoder_hidden = self._decoder_cell(embedded_input, decoder_hidden)
        state['decoder_hidden'] = decoder_hidden

        # shape: (batch_size * sequence_num, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P>
           and the mask would be                     1   1  1  1  1   0  0
           and let the logits be                     l1  l2 l3 l4 l5 l6

        We actually need to compare:
           the sequence           w1 w2 w3 <E> <P> <P>
           with masks             1  1  1  1   0   0
           against                l1 l2 l3 l4 l5 l6
           (where the input was)  <S> w1 w2 w3 <E> <P>
        """
        # shape: (batch_size * sequence_num, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size * sequence_num, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
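
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): how `MultiTurnHred`
# folds the turn dimension into the batch dimension so every utterance is
# decoded independently, then unfolds it again for beam-search output. All
# sizes are toy values made up for the example.
def _demo_turn_flattening() -> None:
    import torch

    batch_size, sequence_num, sequence_length, dim = 2, 3, 5, 4
    embedded = torch.randn(batch_size, sequence_num, sequence_length, dim)

    # Fold: each of the batch_size * sequence_num utterances becomes one "row".
    flat = embedded.view(batch_size * sequence_num, sequence_length, dim)
    # ... encoding / decoding in the model runs at this granularity ...

    # Unfold: restore the (batch, turn) structure; the round trip is lossless.
    restored = flat.view(batch_size, sequence_num, sequence_length, dim)
    assert torch.equal(embedded, restored)
# ---------------------------------------------------------------------------
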
class SimpleSeq2Seq(Model):
    """
    This `SimpleSeq2Seq` class is a `Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    # Parameters

    vocab : `Vocabulary`, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : `TextFieldEmbedder`, required
        Embedder for source side sequences
    encoder : `Seq2SeqEncoder`, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : `int`
        Maximum length of decoded sequences.
    target_namespace : `str`, optional (default = `'tokens'`)
        If the target side vocabulary is different from the source side's, you need to specify
        the target's namespace here. If not, we'll assume it is "tokens", which is also the
        default choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : `int`, optional (default = `'source_embedding_dim'`)
        You can specify an embedding dimensionality for the target side. If not, we'll use the
        same value as the source embedder's.
    target_pretrain_file : `str`, optional (default = `None`)
        Path to a pretrained embedding file for the target side.
    target_decoder_layers : `int`, optional (default = `1`)
        Number of layers for the decoder.
    attention : `Attention`, optional (default = `None`)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    beam_size : `int`, optional (default = `None`)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : `float`, optional (default = `0.`)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        [Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015](https://arxiv.org/abs/1506.03099).
    use_bleu : `bool`, optional (default = `True`)
        If True, the BLEU metric will be calculated during validation.
    ngram_weights : `Iterable[float]`, optional (default = `(0.25, 0.25, 0.25, 0.25)`)
        Weights to assign to scores for each ngram size.
""" def __init__( self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, max_decoding_steps: int, attention: Attention = None, beam_size: int = None, target_namespace: str = "tokens", target_embedding_dim: int = None, scheduled_sampling_ratio: float = 0.0, use_bleu: bool = True, bleu_ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25), target_pretrain_file: str = None, target_decoder_layers: int = 1, ) -> None: super().__init__(vocab) self._target_namespace = target_namespace self._target_decoder_layers = target_decoder_layers self._scheduled_sampling_ratio = scheduled_sampling_ratio # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) self._bleu = BLEU( bleu_ngram_weights, exclude_indices={ pad_index, self._end_index, self._start_index }, ) else: self._bleu = None # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._source_embedder = source_embedder # Encodes the sequence of source embeddings into a sequence of hidden states. self._encoder = encoder num_classes = self.vocab.get_vocab_size(self._target_namespace) # Attention mechanism applied to the encoder output for each step. self._attention = attention # Dense embedding of vocab words in the target space. target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim( ) if not target_pretrain_file: self._target_embedder = Embedding( num_embeddings=num_classes, embedding_dim=target_embedding_dim) else: self._target_embedder = Embedding( embedding_dim=target_embedding_dim, pretrained_file=target_pretrain_file, vocab_namespace=self._target_namespace, vocab=self.vocab, ) # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. self._encoder_output_dim = self._encoder.get_output_dim() self._decoder_output_dim = self._encoder_output_dim if self._attention: # If using attention, a weighted average over encoder outputs will be concatenated # to the previous target embedding to form the input to the decoder at each # time step. self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim else: # Otherwise, the input to the decoder is just the previous target embedding. self._decoder_input_dim = target_embedding_dim # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. # TODO (pradeep): Do not hardcode decoder cell type. if self._target_decoder_layers > 1: self._decoder_cell = LSTM( self._decoder_input_dim, self._decoder_output_dim, self._target_decoder_layers, ) else: self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. 
self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. # Parameters last_predictions : `torch.Tensor` A tensor of shape `(group_size,)`, which gives the indices of the predictions during the last time step. state : `Dict[str, torch.Tensor]` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape `(group_size, *)`, where `*` can be any other number of dimensions. step : `int` The time step in beam search decoding. # Returns Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of `(log_probabilities, updated_state)`, where `log_probabilities` is a tensor of shape `(group_size, num_classes)` containing the predicted log probability of each class for the next step, for each item in the group, while `updated_state` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though `group_size` is not necessarily equal to `batch_size`, since the group may contain multiple states for each source sentence in the batch. """ # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward( self, # type: ignore source_tokens: TextFieldTensors, target_tokens: TextFieldTensors = None, ) -> Dict[str, torch.Tensor]: """ Make forward pass with decoder logic for producing the entire target sequence. # Parameters source_tokens : `TextFieldTensors` The output of `TextField.as_array()` applied on the source `TextField`. This will be passed through a `TextFieldEmbedder` and then through an encoder. target_tokens : `TextFieldTensors`, optional (default = `None`) Output of `TextField.as_array()` applied on target `TextField`. We assume that the target tokens are also represented as a `TextField`. # Returns `Dict[str, torch.Tensor]` """ state = self._encode(source_tokens) if target_tokens: state = self._init_decoder_state(state) # The `_forward_loop` decodes the input sequence and computes the loss during training # and validation. output_dict = self._forward_loop(state, target_tokens) else: output_dict = {} if not self.training: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if target_tokens and self._bleu: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, target_tokens["tokens"]["tokens"]) return output_dict @overrides def make_output_human_readable( self, output_dict: Dict[str, Any]) -> Dict[str, Any]: """ Finalize predictions. This method overrides `Model.make_output_human_readable`, which gets called after `Model.forward`, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the `forward` method.
This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called `predicted_tokens` to the `output_dict`. """ predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for top_k_predictions in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch, # and here we keep all k of them. if len(top_k_predictions.shape) == 1: top_k_predictions = [top_k_predictions] batch_predicted_tokens = [] for indices in top_k_predictions: indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] batch_predicted_tokens.append(predicted_tokens) all_predicted_tokens.append(batch_predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode( self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder(embedded_input, source_mask) return {"source_mask": source_mask, "encoder_outputs": encoder_outputs} def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional(), ) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros( batch_size, self._decoder_output_dim) if self._target_decoder_layers > 1: # shape: (num_layers, batch_size, decoder_output_dim) state["decoder_hidden"] = ( state["decoder_hidden"].unsqueeze(0).repeat( self._target_decoder_layers, 1, 1)) # shape: (num_layers, batch_size, decoder_output_dim) state["decoder_context"] = ( state["decoder_context"].unsqueeze(0).repeat( self._target_decoder_layers, 1, 1)) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"]["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index.
# shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index, dtype=torch.long) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. # shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( input_choices, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(class_probabilities, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = {"predictions": predictions} if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) output_dict["loss"] = loss return output_dict def _forward_beam_search( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Make forward pass during prediction using a beam search.""" batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size, ), fill_value=self._start_index, dtype=torch.long) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _prepare_output_projections( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Decode current state and last prediction to produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (num_layers, group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (num_layers, group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) if self._attention: # shape: (group_size, encoder_output_dim) if self._target_decoder_layers > 1: attended_input = self._prepare_attended_input( decoder_hidden[0], encoder_outputs, source_mask) else: attended_input = self._prepare_attended_input( decoder_hidden, encoder_outputs, source_mask) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((attended_input, embedded_input), -1) else: # shape: (group_size, target_embedding_dim) decoder_input = embedded_input if self._target_decoder_layers > 1: # shape: (1, batch_size, target_embedding_dim) decoder_input = decoder_input.unsqueeze(0) # shape (decoder_hidden): (num_layers, batch_size, decoder_output_dim) # shape (decoder_context): (num_layers, batch_size, decoder_output_dim) # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells. with torch.cuda.amp.autocast(False): _, (decoder_hidden, decoder_context) = self._decoder_cell( decoder_input.float(), (decoder_hidden.float(), decoder_context.float())) else: # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells. with torch.cuda.amp.autocast(False): decoder_hidden, decoder_context = self._decoder_cell( decoder_input.float(), (decoder_hidden.float(), decoder_context.float())) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) if self._target_decoder_layers > 1: output_projections = self._output_projection_layer( decoder_hidden[-1]) else: output_projections = self._output_projection_layer(decoder_hidden) return output_projections, state def _prepare_attended_input( self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.BoolTensor = None, ) -> torch.Tensor: """Apply attention over encoder outputs and decoder state.""" # shape: (batch_size, max_input_sequence_length) input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input @staticmethod def _get_loss( logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.BoolTensor, ) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of `targets` is expected to be greater than that of `logits` because the decoder does not need to compute the output corresponding to the last timestep of `targets`. This method aligns the inputs appropriately to compute the loss. 
During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) return all_metrics default_predictor = "seq2seq"
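# --- Illustration (not part of the model code above) ---
# A toy sketch of the scheduled-sampling coin flip used in _forward_loop: with
# probability scheduled_sampling_ratio the decoder is fed its own previous
# predictions for the whole batch, otherwise the gold tokens (teacher forcing).
# The "decoder step" below is a runnable stand-in, not the real model.
import torch

scheduled_sampling_ratio = 0.25
targets = torch.tensor([[2, 5, 7, 3]])  # gold token ids, shape (batch_size, steps)
last_predictions = torch.tensor([4])    # previous-step argmax from the model

for timestep in range(targets.size(1)):
    if torch.rand(1).item() < scheduled_sampling_ratio:
        input_choices = last_predictions      # feed back the model's own prediction
    else:
        input_choices = targets[:, timestep]  # teacher forcing
    # Stand-in decoder step so the loop runs end to end.
    last_predictions = (input_choices + 1) % 10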
imgs = imgs.cuda() caps = caps.cuda() # perform back-propagation over batch output_dict = model(imgs, caps) loss = output_dict['loss'] optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 10 == 0: # evaluation every 10 epochs model.eval() bleu_eval = BLEU(exclude_indices={ vocab['<start>'], vocab['<end>'], vocab['<pad>'] }) with torch.no_grad(): for data_batch in eval_loader: # load the batch data imgs, caps = data_batch['image'], data_batch['caption'] imgs = imgs.cuda() output_dict = model(imgs) seq = output_dict['seq'] bleu_eval(predictions=seq, gold_targets=caps) bleu_score = bleu_eval.get_metric()['BLEU'] if bleu_score > best_bleu: best_bleu = bleu_score torch.save(model.state_dict(), model_store_path) model.train() # switch back to training mode so dropout etc. are re-enabled for the next epoch
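# --- Illustration (not part of the training loop above) ---
# A hedged sketch of how the AllenNLP BLEU metric accumulates n-gram statistics
# across batches: call it once per batch with index tensors, then read (and
# optionally reset) the running score. The toy ids are illustrative
# (0 = pad, 1 = start, 2 = end).
import torch
from allennlp.training.metrics import BLEU

bleu = BLEU(exclude_indices={0, 1, 2})
predictions = torch.tensor([[1, 7, 8, 9, 10, 11, 2, 0]])
gold_targets = torch.tensor([[1, 7, 8, 9, 10, 11, 2, 0]])

bleu(predictions=predictions, gold_targets=gold_targets)  # accumulate one batch
print(bleu.get_metric(reset=True)["BLEU"])                # approximately 1.0 for an exact match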
class BleuAutoRegressiveSeqDecoder(SeqDecoder): """ An autoregressive decoder that can be used for most seq2seq tasks. Parameters ---------- vocab : ``Vocabulary``, required Vocabulary containing source and target vocabularies. They may be under the same namespace (`tokens`) or the target tokens can have a different namespace, in which case it needs to be specified as `target_namespace`. decoder_net : ``DecoderNet``, required Module that contains implementation of neural network for decoding output elements max_decoding_steps : ``int``, required Maximum length of decoded sequences. target_embedder : ``Embedding``, required Embedder for target tokens. target_namespace : ``str``, optional (default = 'tokens') If the target side vocabulary is different from the source side's, you need to specify the target's namespace here. If not, we'll assume it is "tokens", which is also the default choice for the source side, and this might cause them to share vocabularies. beam_size : ``int``, optional (default = 4) Width of the beam for beam search. tensor_based_metric : ``Metric``, optional (default = None) A metric to track on validation data that takes raw tensors when it is called. This metric must accept two arguments when called: a batched tensor of predicted token indices, and a batched tensor of gold token indices. token_based_metric : ``Metric``, optional (default = None) A metric to track on validation data that takes lists of lists of tokens as input. This metric must accept two arguments when called, both of type `List[List[str]]`. The first is a predicted sequence for each item in the batch and the second is a gold sequence for each item in the batch. scheduled_sampling_ratio : ``float``, optional (default = 0) Defines the ratio between teacher-forced training and real output usage. If it is zero (teacher forcing only) and `decoder_net` supports parallel decoding, we get the output predictions in a single forward pass of the `decoder_net`. tie_output_embedding : ``bool``, optional (default = False) If True, the weights of the output projection layer are tied to (shared with) the target embedder's weights. label_smoothing_ratio : ``Optional[float]``, optional (default = None) Label smoothing ratio passed to the sequence cross entropy loss. bleu_exclude_tokens : ``List[str]``, optional (default = []) Tokens whose indices are excluded from the BLEU computation, in addition to the padding index. """ def __init__( self, vocab: Vocabulary, decoder_net: DecoderNet, max_decoding_steps: int, target_embedder: Embedding, target_namespace: str = "tokens", tie_output_embedding: bool = False, scheduled_sampling_ratio: float = 0, label_smoothing_ratio: Optional[float] = None, beam_size: int = 4, tensor_based_metric: Metric = None, token_based_metric: Metric = None, bleu_exclude_tokens: List = [], ) -> None: super().__init__(target_embedder) self._vocab = vocab # Decodes the sequence of encoded hidden states into a new sequence of hidden states. self._decoder_net = decoder_net self._max_decoding_steps = max_decoding_steps self._target_namespace = target_namespace self._label_smoothing_ratio = label_smoothing_ratio # At prediction time, we use a beam search to find the most likely sequence of target tokens. # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self._vocab.get_token_index( START_SYMBOL, self._target_namespace ) self._end_index = self._vocab.get_token_index( END_SYMBOL, self._target_namespace ) self._beam_search = BeamSearch( self._end_index, max_steps=max_decoding_steps, beam_size=beam_size ) target_vocab_size = self._vocab.get_vocab_size(self._target_namespace) if ( self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim ): raise ConfigurationError( "Target Embedder output_dim doesn't match decoder module's input."
) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. self._output_projection_layer = Linear( self._decoder_net.get_output_dim(), target_vocab_size ) if tie_output_embedding: if ( self._output_projection_layer.weight.shape != self.target_embedder.weight.shape ): raise ConfigurationError( "Can't tie embeddings with output linear layer, due to shape mismatch" ) self._output_projection_layer.weight = self.target_embedder.weight # These metrics will be updated during training and validation if isinstance(tensor_based_metric, BLEU): pad_index = self._vocab.get_token_index( self._vocab._padding_token, self._target_namespace ) new_exclude_indices = set([pad_index]) for token in bleu_exclude_tokens: new_exclude_indices.add( self._vocab.get_token_index(token, self._target_namespace) ) new_exclude_indices.update(tensor_based_metric._exclude_indices) logger.info( f"Reconstruct BLEU to exclude indices {' '.join(map(str, new_exclude_indices))}" ) self._tensor_based_metric = BLEU( tensor_based_metric._ngram_weights, new_exclude_indices ) else: self._tensor_based_metric = tensor_based_metric self._token_based_metric = token_based_metric self._scheduled_sampling_ratio = scheduled_sampling_ratio def _forward_beam_search( self, state: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: """ Prepare inputs for the beam search, does beam search and returns beam search results. """ batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size,), fill_value=self._start_index ) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step ) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _forward_loss( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] # Prepare embeddings for targets. They will be used as gold embeddings during decoder training # shape: (batch_size, max_target_sequence_length, embedding_dim) target_embedding = self.target_embedder(targets) # shape: (batch_size, max_target_batch_sequence_length) target_mask = util.get_text_field_mask(target_tokens) if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel: _, decoder_output = self._decoder_net( previous_state=state, previous_steps_predictions=target_embedding[:, :-1, :], encoder_outputs=encoder_outputs, source_mask=source_mask, previous_steps_mask=target_mask[:, :-1], ) # shape: (group_size, max_target_sequence_length, num_classes) logits = self._output_projection_layer(decoder_output) else: batch_size = source_mask.size()[0] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. 
num_decoding_steps = target_sequence_length - 1 # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full( (batch_size,), fill_value=self._start_index ) # shape: (batch_size, steps, target_embedding_dim) steps_embeddings = torch.Tensor([]) step_logits: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if ( self.training and torch.rand(1).item() < self._scheduled_sampling_ratio ): # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. # shape: (batch_size, steps, target_embedding_dim) state["previous_steps_predictions"] = steps_embeddings # shape: (batch_size, ) effective_last_prediction = last_predictions else: # shape: (batch_size, ) effective_last_prediction = targets[:, timestep] if timestep == 0: state["previous_steps_predictions"] = torch.Tensor([]) else: # shape: (batch_size, steps, target_embedding_dim) state["previous_steps_predictions"] = target_embedding[ :, :timestep ] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( effective_last_prediction, state ) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(output_projections, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes # shape: (batch_size, 1, target_embedding_dim) last_predictions_embeddings = self.target_embedder( last_predictions ).unsqueeze(1) # This step is required, since we want to keep two different prediction histories: gold and real if steps_embeddings.shape[-1] == 0: # There are no previous steps, except for the start vectors in ``last_predictions`` # shape: (group_size, 1, target_embedding_dim) steps_embeddings = last_predictions_embeddings else: # shape: (group_size, steps_count, target_embedding_dim) steps_embeddings = torch.cat( [steps_embeddings, last_predictions_embeddings], 1 ) # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) # TODO: We will be using beam search to get predictions for validation, but if beam size is 1 # we could consider taking the last_predictions here and building step_predictions # and use that instead of running beam search again, if performance in validation is taking a hit output_dict = {"loss": loss} return output_dict def _prepare_output_projections( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Decode current state and last prediction to produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, steps_count, decoder_output_dim) previous_steps_predictions = state.get("previous_steps_predictions") # shape: (batch_size, 1, target_embedding_dim) last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze( 1 ) if ( previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0 ): # There is no previous steps, except for start vectors in ``last_predictions`` # shape: (group_size, 1, target_embedding_dim) previous_steps_predictions = last_predictions_embeddings else: # shape: (group_size, steps_count, target_embedding_dim) previous_steps_predictions = torch.cat( [previous_steps_predictions, last_predictions_embeddings], 1 ) decoder_state, decoder_output = self._decoder_net( previous_state=state, encoder_outputs=encoder_outputs, source_mask=source_mask, previous_steps_predictions=previous_steps_predictions, ) state["previous_steps_predictions"] = previous_steps_predictions # Update state with new decoder state, override previous state state.update(decoder_state) if self._decoder_net.decodes_parallel: decoder_output = decoder_output[:, -1, :] # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_output) return output_projections, state def _get_loss( self, logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor, ) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits( logits, relevant_targets, relevant_mask, label_smoothing=self._label_smoothing_ratio, ) def get_output_dim(self): return self._decoder_net.get_output_dim() def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. 
state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections( last_predictions, state ) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if not self.training: if self._tensor_based_metric is not None: all_metrics.update( self._tensor_based_metric.get_metric(reset=reset) # type: ignore ) if self._token_based_metric is not None: all_metrics.update(self._token_based_metric.get_metric(reset=reset)) # type: ignore return all_metrics @overrides def forward( self, encoder_out: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None, ) -> Dict[str, torch.Tensor]: state = encoder_out if target_tokens: decoder_init_state = self._decoder_net.init_decoder_state(state) decoder_init_state.update(state) output_dict = self._forward_loss(decoder_init_state, target_tokens) else: output_dict = {} if not self.training: decoder_init_state = self._decoder_net.init_decoder_state(state) decoder_init_state.update(state) predictions = self._forward_beam_search(decoder_init_state) output_dict.update(predictions) if target_tokens: if self._tensor_based_metric is not None: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._tensor_based_metric( # type: ignore best_predictions, target_tokens["tokens"] ) if self._token_based_metric is not None: output_dict = self.post_process(output_dict) predicted_tokens = output_dict["predicted_tokens"] self._token_based_metric( # type: ignore predicted_tokens, self.indices_to_tokens(target_tokens["tokens"][:, 1:]), ) return output_dict @overrides def post_process( self, output_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: """ This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. 
""" predicted_indices = output_dict["predictions"] all_predicted_tokens = self.indices_to_tokens(predicted_indices) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def indices_to_tokens(self, batch_indeces: numpy.ndarray) -> List[List[str]]: if not isinstance(batch_indeces, numpy.ndarray): batch_indeces = batch_indeces.detach().cpu().numpy() all_tokens = [] for indices in batch_indeces: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[: indices.index(self._end_index)] tokens = [ self._vocab.get_token_from_index(x, namespace=self._target_namespace) for x in indices ] all_tokens.append(tokens) return all_tokens
class VAE(Model): """ This ``VAE`` class is a :class:`Model` which implements a simple VAE as first described in https://arxiv.org/pdf/1511.06349.pdf (Bowman et al., 2015). Parameters ---------- vocab : ``Vocabulary``, required Vocabulary containing source and target vocabularies. They may be under the same namespace (`tokens`) or the target tokens can have a different namespace, in which case it needs to be specified as `target_namespace`. variational_encoder : ``VariationalEncoder``, required The encoder model through which the source tokens are passed; its ``latent_dim`` gives the dimension of the latent vector z, which is not necessarily the same size as the encoder output dim. decoder : ``Model``, required The variational decoder model to which the latent variable is passed. kl_weight : ``LossWeight``, required A (possibly annealed) weight applied to the KL term of the loss; its ``step()`` method is advanced once per training forward pass. temperature : ``float``, optional (default = 1.0) Temperature used when reparametrizing the latent variable at inference time. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. """ def __init__( self, vocab: Vocabulary, variational_encoder: VariationalEncoder, decoder: Decoder, kl_weight: LossWeight, temperature: float = 1.0, initializer: InitializerApplicator = InitializerApplicator() ) -> None: super(VAE, self).__init__(vocab) self._encoder = variational_encoder self._decoder = decoder self._latent_dim = variational_encoder.latent_dim self._encoder_output_dim = self._encoder.get_encoder_output_dim() self._start_index = self.vocab.get_token_index(START_SYMBOL) self._end_index = self.vocab.get_token_index(END_SYMBOL) self._pad_index = self.vocab.get_token_index(self.vocab._padding_token) # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={ self._pad_index, self._end_index, self._start_index }) self._kl_metric = Average() self.kl_weight = kl_weight self._temperature = temperature initializer(self) @overrides def forward( self, # type: ignore source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Make the forward pass for training/validation/test time.
""" encoder_outs = self._encoder(source_tokens) p_z = encoder_outs['prior'] q_z = encoder_outs['posterior'] kl_weight = self.kl_weight.get() if self.training: z = q_z.rsample() self.kl_weight.step() else: z = self._encoder.reparametrize(p_z, q_z, self._temperature) batch_size = z.size(0) kld = kl_divergence(q_z, p_z).sum() / batch_size self._kl_metric(kld) output_dict = {'z': z, 'predictions': source_tokens['tokens']} if not target_tokens: return output_dict # Do Decoding output_dict.update(self._decoder(z, target_tokens)) rec_loss = output_dict['loss'] kl_loss = kld * kl_weight output_dict['loss'] = rec_loss + kl_loss if not self.training: best_predictions = output_dict["predictions"] self._bleu(best_predictions, target_tokens["tokens"]) return output_dict @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) all_metrics.update({'klw': float(self.kl_weight.get())}) all_metrics.update( {'kl': float(self._kl_metric.get_metric(reset=reset))}) return all_metrics def generate(self, num_to_sample: int = 1): cuda_device = self._get_prediction_device() prior_mean = nn_util.move_to_device( torch.zeros((num_to_sample, self._latent_dim)), cuda_device) prior_stddev = torch.ones_like(prior_mean) prior = Normal(prior_mean, prior_stddev) latent = prior.sample() generated = self._decoder.generate(latent) return self.decode(generated) @overrides # simple_seq2seq's decode def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index(x) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict
class AttnSupSeq2Seq(Model): """ Adaptation of the ``SimpleSeq2Seq`` class in allennlp_models, with an auxiliary attention-supervision loss. Parameters ---------- vocab : ``Vocabulary``, required Vocabulary containing source and target vocabularies. They may be under the same namespace (`tokens`) or the target tokens can have a different namespace, in which case it needs to be specified as `target_namespace`. source_embedder : ``TextFieldEmbedder``, required Embedder for source side sequences encoder : ``Seq2SeqEncoder``, required The encoder of the "encoder/decoder" model max_decoding_steps : ``int`` Maximum length of decoded sequences. target_namespace : ``str``, optional (default = 'tokens') If the target side vocabulary is different from the source side's, you need to specify the target's namespace here. If not, we'll assume it is "tokens", which is also the default choice for the source side, and this might cause them to share vocabularies. target_embedding_dim : ``int``, optional (default = source_embedding_dim) You can specify an embedding dimensionality for the target side. If not, we'll use the same value as the source embedder's. attention : ``Attention``, required The function used to compute similarity between the decoder hidden state and encoder outputs; its unnormalized logits also feed the attention-supervision loss. schema_path : ``str``, optional (default = None) Path to a schema file; if given, SQL-specific metrics (knowledge-base constants accuracy and schema-free template accuracy) are tracked. missing_alignment_int : ``int``, optional (default = 0) Integer marking decoding steps that have no alignment; such steps are masked out of the attention-supervision loss. indexfield_padding_index : ``int``, optional (default = -1) Padding value used by the alignment ``IndexField`` entries. beam_size : ``int``, optional (default = None) Width of the beam for beam search. If not specified, greedy decoding is used. scheduled_sampling_ratio : ``float``, optional (default = 0.) At each timestep during training, we sample a random number between 0 and 1, and if it is not less than this value, we use the ground truth labels for the whole batch. Else, we use the predictions from the previous time step for the whole batch. If this value is 0.0 (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not using target side ground truth labels. See the following paper for more information: `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al., 2015 <https://arxiv.org/abs/1506.03099>`_. use_bleu : ``bool``, optional (default = True) If True, the BLEU metric will be calculated during validation. emb_dropout : ``float``, optional (default = 0.0) Dropout applied to the encoder outputs. dec_dropout : ``float``, optional (default = 0.0) Dropout applied to the decoder input and output. attn_loss_lambda : ``float``, optional (default = 0.5) Weight of the auxiliary attention-supervision loss relative to the generation loss. token_based_metric : ``Metric``, optional (default = None) Token-level metric to track on validation data; defaults to ``TokenSequenceAccuracy``. """ def __init__(self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, max_decoding_steps: int, attention: Attention, schema_path: str = None, missing_alignment_int: int = 0, indexfield_padding_index: int = -1, beam_size: int = None, target_namespace: str = "tokens", target_embedding_dim: int = None, scheduled_sampling_ratio: float = 0., use_bleu: bool = True, emb_dropout: float = 0.0, dec_dropout: float = 0.0, attn_loss_lambda: float = 0.5, token_based_metric: Metric = None) -> None: super(AttnSupSeq2Seq, self).__init__(vocab) self._target_namespace = target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio self._indexfield_padding_index = indexfield_padding_index self._missing_alignment_int = missing_alignment_int # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence.
self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={ pad_index, self._end_index, self._start_index }) else: self._bleu = None if token_based_metric: self._token_based_metric = token_based_metric else: self._token_based_metric = TokenSequenceAccuracy() # log attention supervision CE loss as a metric self._attn_sup_loss = Average() self._sql_metrics = schema_path is not None if self._sql_metrics: # SQL specific metrics: match between the templates free of schema constants, # and match between the schema constants self._schema_free_match = GlobalTemplAccuracy( schema_path=schema_path) self._kb_match = KnowledgeBaseConstsAccuracy( schema_path=schema_path) # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._source_embedder = source_embedder self._emb_dropout = Dropout(p=emb_dropout) self._dec_dropout = Dropout(p=dec_dropout) self._attn_loss_lambda = attn_loss_lambda # Encodes the sequence of source embeddings into a sequence of hidden states. self._encoder = encoder num_classes = self.vocab.get_vocab_size(self._target_namespace) # Attention mechanism applied to the encoder output for each step. self._attention = attention self._attention._normalize = False # Dense embedding of vocab words in the target space. target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim( ) self._target_embedder = Embedding(num_classes, target_embedding_dim) # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. self._encoder_output_dim = self._encoder.get_output_dim() self._decoder_output_dim = self._encoder_output_dim # A weighted average over encoder outputs will be concatenated to the previous target embedding # to form the input to the decoder at each time step. self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. # TODO (pradeep): Do not hardcode decoder cell type. self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. 
Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ # shape: (group_size, num_classes) _, output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward_on_instances( self, instances: List[Instance]) -> List[Dict[str, numpy.ndarray]]: """ Takes a list of :class:`~allennlp.data.instance.Instance`s, converts that text into arrays using this model's :class:`Vocabulary`, passes those arrays through :func:`self.forward()` and :func:`self.decode()` (which by default does nothing) and returns the result. Before returning the result, we convert any ``torch.Tensors`` into numpy arrays and separate the batched output into a list of individual dicts per instance. Note that typically this will be faster on a GPU (and conditionally, on a CPU) than repeated calls to :func:`forward_on_instance`. Parameters ---------- instances : List[Instance], required The instances to run the model on. Returns ------- A list of the model's output for each instance. """ batch_size = len(instances) with torch.no_grad(): cuda_device = self._get_prediction_device() dataset = Batch(instances) dataset.index_instances(self.vocab) model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device) outputs = self.decode(self(**model_input)) instance_separated_output: List[Dict[str, numpy.ndarray]] = [ {} for _ in dataset.instances ] for name, output in list(outputs.items()): if isinstance(output, torch.Tensor): # NOTE(markn): This is a hack because 0-dim pytorch tensors are not iterable. # This occurs with batch size 1, because we still want to include the loss in that case. if output.dim() == 0: output = output.unsqueeze(0) if output.size(0) != batch_size: self._maybe_warn_for_unseparable_batches(name) continue output = output.detach().cpu().numpy() elif len(output) != batch_size: self._maybe_warn_for_unseparable_batches(name) continue for instance_output, batch_element in zip( instance_separated_output, output): instance_output[name] = batch_element for instance_output, instance_input in zip( instance_separated_output, instances): for field in instance_input.fields: try: instance_output[field] = instance_input.fields[ field].tokens except Exception: continue return instance_separated_output @overrides def forward( self, # type: ignore source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None, alignment_sequence: torch.Tensor = None ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Make forward pass with decoder logic for producing the entire target sequence.
Parameters ---------- source_tokens : ``Dict[str, torch.LongTensor]`` The output of `TextField.as_array()` applied on the source `TextField`. This will be passed through a `TextFieldEmbedder` and then through an encoder. target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) Output of `TextField.as_array()` applied on target `TextField`. We assume that the target tokens are also represented as a `TextField`. alignment_sequence : ``torch.Tensor``, optional (default = None) Output of `TextField.as_array()` applied on the alignment `TextField`. Returns ------- Dict[str, torch.Tensor] """ state = self._encode(source_tokens) if target_tokens: state = self._init_decoder_state(state) # Remove the trailing dimension (from ListField[ListField[IndexField]]). alignment_sequence = alignment_sequence.squeeze(-1) # The `_forward_loop` decodes the input sequence and computes the loss during training # and validation. output_dict = self._forward_loop(state, target_tokens, alignment_sequence) else: output_dict = {} if not self.training: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if target_tokens: if self._bleu: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, target_tokens["tokens"]) predicted_tokens = self.decode(output_dict)["predicted_tokens"] target_tokens_str = self.decode_target_tokens(target_tokens) if self._token_based_metric: self._token_based_metric(predicted_tokens, target_tokens_str) if self._sql_metrics: self._kb_match(predicted_tokens, target_tokens_str) self._schema_free_match(predicted_tokens, target_tokens_str) # In case of attention coverage mechanism, reset the coverage vector after every batch... try: self._attention.reset_coverage_vector() except Exception: pass return output_dict def decode_target_tokens(self, target_tokens): target_indices = target_tokens['tokens'].detach().cpu().numpy() target_tokens_output = [] for i in range(target_indices.shape[0]): cur_target_indices = target_indices[i] cur_target_indices = list(cur_target_indices) if self._end_index in cur_target_indices: cur_target_indices = cur_target_indices[:cur_target_indices. index(self._end_index)] if self._start_index in cur_target_indices: cur_target_indices = cur_target_indices[ cur_target_indices.index(self._start_index) + 1:] target_tokens_str = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in cur_target_indices ] target_tokens_output.append(target_tokens_str) return target_tokens_output @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
""" predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode( self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder(embedded_input, source_mask) encoder_outputs = self._emb_dropout(encoder_outputs) return { "source_mask": source_mask, "encoder_outputs": encoder_outputs, } def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros( batch_size, self._decoder_output_dim) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None, alignment_sequence: torch.Tensor = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] step_attn_weights: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. 
# shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape (input_weights): (batch_size, max_input_sequence_length) # shape (output_projections): (batch_size, num_classes) input_weights, output_projections, state = self._prepare_output_projections( input_choices, state) step_attn_weights.append(input_weights.unsqueeze(1)) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(class_probabilities, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) # shape: (batch_size, num_decoding_steps, max_input_sequence_length) attention_input_weights = torch.cat(step_attn_weights[:-1], 1) output_dict = { "predictions": predictions, 'attention_input_weights': attention_input_weights } if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # shape: (batch_size, num_decoding_steps, max_input_sequence_length) alignment_mask = self._get_alignment_mask(alignment_sequence) # Compute loss. target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) attn_sup_loss = self._get_attn_sup_loss(attention_input_weights, alignment_mask, alignment_sequence) self._attn_sup_loss(attn_sup_loss.detach().cpu().item()) output_dict["loss"] = loss + self._attn_loss_lambda * attn_sup_loss return output_dict def _forward_beam_search( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Make forward pass during prediction using a beam search.""" batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size, ), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _prepare_output_projections(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long """ Decode current state and last prediction to produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) # shape: (group_size, encoder_output_dim) attended_input, input_weights = self._prepare_attended_input( decoder_hidden, encoder_outputs, source_mask) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((attended_input, embedded_input), -1) decoder_input = self._dec_dropout(decoder_input) # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer( self._dec_dropout(decoder_hidden)) return input_weights, output_projections, state def _prepare_attended_input( self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length) encoder_outputs_mask = encoder_outputs_mask.float() # shape: (batch_size, max_input_sequence_length) input_logits = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # The attention mechanism returns the logits that are needed for the attention # supervision loss, so we normalize them here before the weighted sum. input_weights = masked_softmax(input_logits, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input, input_logits @staticmethod def _get_attn_sup_loss(attn_weights: torch.Tensor, alignment_mask: torch.Tensor, alignment_sequence: torch.Tensor) -> torch.Tensor: """ Compute the attention supervision CE loss. For each decoding step, take the index of the aligned source token and compute cross entropy between the attention logits over the input sequence and that index. """ # shape: (batch_size, max_decoding_steps, max_input_seq_length) attn_weights = attn_weights.float() alignment_sequence[alignment_sequence == -1] = 0 # For each attn_weights[batch_index, step_index, :], the target class is the index # stored in alignment_sequence[batch_index, step_index]. return util.sequence_cross_entropy_with_logits(attn_weights, alignment_sequence, alignment_mask) def _get_alignment_mask(self, alignment_sequence): """ The alignment mask combines the target mask with a mask over steps that have no alignment. shape: (batch_size, max_steps, max_input) """ pad_mask = alignment_sequence != self._indexfield_padding_index missing_mask = alignment_sequence != self._missing_alignment_int return pad_mask * missing_mask @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss.
Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if not self.training: if self._bleu: all_metrics.update(self._bleu.get_metric(reset=reset)) if self._token_based_metric: all_metrics.update( self._token_based_metric.get_metric(reset=reset)) if self._sql_metrics: all_metrics.update(self._kb_match.get_metric(reset=reset)) all_metrics.update( self._schema_free_match.get_metric(reset=reset)) all_metrics['attn_sup_loss'] = self._attn_sup_loss.get_metric( reset=reset) return all_metrics
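# --- Illustrative sketch (not part of the original model code). The helper below is
# hypothetical and only demonstrates the one-step target shift that ``_get_loss`` above
# documents: the logit at timestep i is compared against the gold token at timestep i + 1,
# so both the targets and the mask drop their first position.
def _demo_target_shift_for_loss():
    import torch
    # <S> w1 w2 w3 <E> <P> <P>, using arbitrary token ids for illustration.
    targets = torch.tensor([[2, 10, 11, 12, 3, 0, 0]])
    target_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]])
    relevant_targets = targets[:, 1:].contiguous()   # w1 w2 w3 <E> <P> <P>
    relevant_mask = target_mask[:, 1:].contiguous()  # 1  1  1  1  0  0
    return relevant_targets, relevant_mask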
class Bart(Model): """ BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here uses a language modeling head and thus can be used for text generation. # Parameters model_name : `str`, required Name of the pre-trained BART model to use. Available options can be found in `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`. vocab : `Vocabulary`, required Vocabulary containing source and target vocabularies. beam_search : `Lazy[BeamSearch]`, optional (default = `Lazy(BeamSearch)`) This is used during inference to select the tokens of the decoded output sequence. indexer : `PretrainedTransformerIndexer`, optional (default = `None`) Indexer to be used for converting decoded sequences of ids to sequences of tokens. encoder : `Seq2SeqEncoder`, optional (default = `None`) Encoder to be used in BART. By default, the original BART encoder is used. """ def __init__( self, model_name: str, vocab: Vocabulary, beam_search: Lazy[BeamSearch] = Lazy(BeamSearch), indexer: PretrainedTransformerIndexer = None, encoder: Seq2SeqEncoder = None, **kwargs, ): super().__init__(vocab) self.bart = BartForConditionalGeneration.from_pretrained(model_name) self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens") self._start_id = self.bart.config.bos_token_id # CLS self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id self._end_id = self.bart.config.eos_token_id # SEP self._pad_id = self.bart.config.pad_token_id # PAD # At prediction time, we'll use a beam search to find the best target sequence. # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as # kwargs. If so, pass them to the BeamSearch constructor and issue a DeprecationWarning. deprecation_warning = ( "The parameter {} has been deprecated." " Provide this parameter as argument to beam_search instead." ) beam_search_extras = {} if "beam_size" in kwargs: beam_search_extras["beam_size"] = kwargs["beam_size"] warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning) if "max_decoding_steps" in kwargs: beam_search_extras["max_steps"] = kwargs["max_decoding_steps"] warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning) self._beam_search = beam_search.construct( end_index=self._end_id, vocab=self.vocab, **beam_search_extras ) self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id}) self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id}) # Replace bart encoder with given encoder. We need to extract the two embedding layers so that # we can use them in the encoder wrapper if encoder is not None: assert ( encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size ) self.bart.model.encoder = _BartEncoderWrapper( encoder, self.bart.model.encoder.embed_tokens, self.bart.model.encoder.embed_positions, ) def forward( self, source_tokens: TextFieldTensors, target_tokens: TextFieldTensors = None ) -> Dict[str, torch.Tensor]: """ Performs the forward step of Bart. # Parameters source_tokens : `TextFieldTensors`, required The source tokens for the encoder. We assume they are stored under the `tokens` key. target_tokens : `TextFieldTensors`, optional (default = `None`) The target tokens for the decoder. We assume they are stored under the `tokens` key.
If no target tokens are given, the source tokens are shifted to the right by 1. # Returns `Dict[str, torch.Tensor]` During training, this dictionary contains the `decoder_logits` of shape `(batch_size, max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions` of shape `(batch_size, max_decoding_steps)` and `log_probabilities` of shape `(batch_size,)`. """ inputs = source_tokens targets = target_tokens input_ids, input_mask = inputs["tokens"]["token_ids"], inputs["tokens"]["mask"] outputs = {} # If no targets are provided, then shift input to right by 1. Bart already does this internally # but it does not use them for loss calculation. if targets is not None: target_ids, target_mask = targets["tokens"]["token_ids"], targets["tokens"]["mask"] else: target_ids = input_ids[:, 1:] target_mask = input_mask[:, 1:] if self.training: bart_outputs = self.bart( input_ids=input_ids, attention_mask=input_mask, decoder_input_ids=target_ids[:, :-1].contiguous(), decoder_attention_mask=target_mask[:, :-1].contiguous(), use_cache=False, return_dict=True, ) outputs["decoder_logits"] = bart_outputs.logits # The BART paper mentions label smoothing of 0.1 for sequence generation tasks outputs["loss"] = sequence_cross_entropy_with_logits( bart_outputs.logits, cast(torch.LongTensor, target_ids[:, 1:].contiguous()), cast(torch.BoolTensor, target_mask[:, 1:].contiguous()), label_smoothing=0.1, average="token", ) else: # Use decoder start id and start of sentence to start decoder initial_decoder_ids = torch.tensor( [[self._decoder_start_id]], dtype=input_ids.dtype, device=input_ids.device, ).repeat(input_ids.shape[0], 1) initial_state = { "input_ids": input_ids, "input_mask": input_mask, } beam_result = self._beam_search.search( initial_decoder_ids, initial_state, self.take_step ) predictions = beam_result[0] max_pred_indices = ( beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1]) ) predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1) self._rouge(predictions, target_ids) self._bleu(predictions, target_ids) outputs["predictions"] = predictions outputs["log_probabilities"] = ( beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1) ) self.make_output_human_readable(outputs) return outputs @staticmethod def _decoder_cache_to_dict(decoder_cache: DecoderCacheType) -> Dict[str, torch.Tensor]: cache_dict = {} for layer_index, layer_cache in enumerate(decoder_cache): # Each layer caches the key and value tensors for its self-attention and cross-attention. # Hence the `layer_cache` tuple has 4 elements. assert len(layer_cache) == 4 for tensor_index, tensor in enumerate(layer_cache): key = f"decoder_cache_{layer_index}_{tensor_index}" cache_dict[key] = tensor return cache_dict def _dict_to_decoder_cache(self, cache_dict: Dict[str, torch.Tensor]) -> DecoderCacheType: decoder_cache = [] for layer_index in range(len(self.bart.model.decoder.layers)): base_key = f"decoder_cache_{layer_index}_" layer_cache = ( cache_dict[base_key + "0"], cache_dict[base_key + "1"], cache_dict[base_key + "2"], cache_dict[base_key + "3"], ) decoder_cache.append(layer_cache) assert decoder_cache return tuple(decoder_cache) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take step during beam search. # Parameters last_predictions : `torch.Tensor` The predicted token ids from the previous step.
Shape: `(group_size,)` state : `Dict[str, torch.Tensor]` State required to generate next set of predictions step : `int` The time step in beam search decoding. # Returns `Tuple[torch.Tensor, Dict[str, torch.Tensor]]` A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)` and an updated state dictionary. """ if len(last_predictions.shape) == 1: last_predictions = last_predictions.unsqueeze(-1) decoder_cache = None decoder_cache_dict = { k: state[k].contiguous() for k in state if k not in {"input_ids", "input_mask", "encoder_states"} } if len(decoder_cache_dict) != 0: decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict) encoder_outputs = (state["encoder_states"],) if "encoder_states" in state else None outputs = self.bart( input_ids=state["input_ids"] if encoder_outputs is None else None, attention_mask=state["input_mask"], encoder_outputs=encoder_outputs, decoder_input_ids=last_predictions, past_key_values=decoder_cache, use_cache=True, return_dict=True, ) logits = outputs.logits[:, -1, :] log_probabilities = F.log_softmax(logits, dim=-1) decoder_cache = outputs.past_key_values if decoder_cache is not None: decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache) state.update(decoder_cache_dict) state["encoder_states"] = outputs.encoder_last_hidden_state return log_probabilities, state def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: """ # Parameters output_dict : `Dict[str, torch.Tensor]` A dictionary containing a batch of predictions with key `predictions`. The tensor should have shape `(batch_size, max_sequence_length)` # Returns `Dict[str, Any]` Original `output_dict` with an additional `predicted_tokens` key that maps to a list of lists of tokens. """ predictions = output_dict["predictions"] predicted_tokens = [None] * predictions.shape[0] for i in range(predictions.shape[0]): predicted_tokens[i] = self._indexer.indices_to_tokens( {"token_ids": predictions[i].tolist()}, self.vocab, ) output_dict["predicted_tokens"] = predicted_tokens # type: ignore output_dict["predicted_text"] = self._indexer._tokenizer.batch_decode( predictions.tolist(), skip_special_tokens=True ) return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics: Dict[str, float] = {} if not self.training: metrics.update(self._rouge.get_metric(reset=reset)) metrics.update(self._bleu.get_metric(reset=reset)) return metrics default_predictor = "seq2seq"
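# --- Illustrative sketch (hypothetical helper, not used by the models above): the round
# trip performed by ``Bart._decoder_cache_to_dict`` / ``_dict_to_decoder_cache``. Each
# decoder layer caches four tensors (self-attention key/value, cross-attention key/value);
# flattening them into string-keyed state entries lets BeamSearch reorder them alongside
# the other state tensors. The dummy tensor shapes here are assumptions for illustration.
def _demo_decoder_cache_round_trip():
    import torch
    num_layers = 2
    decoder_cache = tuple(
        tuple(torch.zeros(1, 2, 3, 4) for _ in range(4)) for _ in range(num_layers)
    )
    # Flatten: one entry per (layer, tensor) pair, mirroring _decoder_cache_to_dict.
    cache_dict = {
        f"decoder_cache_{i}_{j}": tensor
        for i, layer_cache in enumerate(decoder_cache)
        for j, tensor in enumerate(layer_cache)
    }
    # Rebuild the nested structure, mirroring _dict_to_decoder_cache.
    rebuilt = tuple(
        tuple(cache_dict[f"decoder_cache_{i}_{j}"] for j in range(4))
        for i in range(num_layers)
    )
    # The rebuilt cache contains exactly the same tensor objects, layer by layer.
    for layer_a, layer_b in zip(decoder_cache, rebuilt):
        for a, b in zip(layer_a, layer_b):
            assert a is b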
class FactParaphraseSeq2Seq(Model): """ Given facts and dialog acts, this model generates the paraphrased message. TODO: add dialog & dialog acts history This implementation is based on the default SimpleSeq2Seq model, which takes a sequence, encodes it, and then uses the encoded representations to decode another sequence. """ def __init__( self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, source_encoder: Seq2SeqEncoder, max_decoding_steps: int, dialog_acts_encoder: FeedForward = None, attention: Attention = None, attention_function: SimilarityFunction = None, n_dialog_acts: int = None, beam_size: int = None, target_namespace: str = "tokens", target_embedding_dim: int = None, scheduled_sampling_ratio: float = 0.0, use_bleu: bool = True, use_dialog_acts: bool = True, regularizers: Optional[RegularizerApplicator] = None, ) -> None: super().__init__(vocab, regularizers) self._target_namespace = target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio # We need the start symbol to provide as the input at the first # timestep of decoding, and end symbol as a way to indicate the end # of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) self._bleu = BLEU(exclude_indices={ pad_index, self._end_index, self._start_index }) else: self._bleu = None # At prediction time, we use a beam search to find the most # likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source (Facts) vocab tokens. self._source_embedder = source_embedder # Encodes the sequence of source embeddings into a sequence of hidden states. self._source_encoder = source_encoder if use_dialog_acts: # Dense embedding of dialog acts. da_embedding_dim = dialog_acts_encoder.get_input_dim() self._dialog_acts_embedder = EmbeddingBag(n_dialog_acts, da_embedding_dim) # Encodes dialog acts self._dialog_acts_encoder = dialog_acts_encoder else: self._dialog_acts_embedder = None self._dialog_acts_encoder = None num_classes = self.vocab.get_vocab_size(self._target_namespace) # Attention mechanism applied to the encoder output for each step. if attention: if attention_function: raise ConfigurationError( "You can only specify an attention module or an " "attention function, but not both.") self._attention = attention elif attention_function: self._attention = LegacyAttention(attention_function) else: self._attention = None # Dense embedding of vocab words in the target space. target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim( ) self._target_embedder = Embedding(num_classes, target_embedding_dim) # Decoder output dim needs to be the same as the encoder output dim # since we initialize the hidden state of the decoder with the final # hidden state of the encoder.
self._encoder_output_dim = self._source_encoder.get_output_dim() if use_dialog_acts: self._merge_encoder = Sequential( Linear( self._source_encoder.get_output_dim() + self._dialog_acts_encoder.get_output_dim(), self._encoder_output_dim, )) self._decoder_output_dim = self._encoder_output_dim if self._attention: # If using attention, a weighted average over encoder outputs will # be concatenated to the previous target embedding to form the input # to the decoder at each time step. self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim else: # Otherwise, the input to the decoder is just the previous target embedding. self._decoder_input_dim = target_embedding_dim # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. # TODO (pradeep): Do not hardcode decoder cell type. self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. """ # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward( self, # type: ignore source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None, dialog_acts: Optional[torch.Tensor] = None, sender: Optional[torch.Tensor] = None, metadata: Optional[Dict] = None, ) -> Dict[str, torch.Tensor]: """ Make forward pass with decoder logic for producing the entire target sequence. """ source_state, dialog_acts_state = self._encode(source_tokens, dialog_acts) if target_tokens: state = self._init_decoder_state(source_state, dialog_acts_state) # The `_forward_loop` decodes the input sequence and # computes the loss during training and validation. output_dict = self._forward_loop(state, target_tokens) else: output_dict = {} if not self.training: state = self._init_decoder_state(source_state, dialog_acts_state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if target_tokens and self._bleu: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, target_tokens["tokens"]) return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. """ predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence # in the batch but we just want the single best.
if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode( self, source_tokens: Dict[str, torch.Tensor], dialog_acts: torch.Tensor = None ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: # Encode source tokens source_state = self._encode_source_tokens(source_tokens) # Encode dialog acts if self._dialog_acts_encoder: dialog_acts_state = self._encode_dialog_acts(dialog_acts) else: dialog_acts_state = None return (source_state, dialog_acts_state) def _encode_source_tokens( self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._source_encoder(embedded_input, source_mask) return {"source_mask": source_mask, "encoder_outputs": encoder_outputs} def _encode_dialog_acts(self, dialog_acts: torch.Tensor) -> torch.Tensor: # shape: (batch_size, dialog_acts_embeddings_size) embedded_dialog_acts = self._dialog_acts_embedder(dialog_acts) # shape: (batch_size, dim_encoder) dialog_acts_state = self._dialog_acts_encoder(embedded_dialog_acts) return dialog_acts_state def _init_decoder_state( self, source_state: Dict[str, torch.Tensor], dialog_acts_state: torch.Tensor = None, ) -> Dict[str, torch.Tensor]: batch_size = source_state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( source_state["encoder_outputs"], source_state["source_mask"], self._source_encoder.is_bidirectional(), ) # Condition the source tokens state with dialog acts state if self._dialog_acts_encoder: final_encoder_output = self._merge_encoder( torch.cat([final_encoder_output, dialog_acts_state], dim=1)) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) source_state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) source_state["decoder_context"] = source_state[ "encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim) return source_state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None, ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. 
# shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of # 1 - _scheduled_sampling_ratio during training. # shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( input_choices, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(class_probabilities, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = {"predictions": predictions} if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) output_dict["loss"] = loss return output_dict def _forward_beam_search( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Make forward pass during prediction using a beam search.""" batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size, ), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _prepare_output_projections( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Decode current state and last prediction to produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) if self._attention: # shape: (group_size, encoder_output_dim) attended_input = self._prepare_attended_input( decoder_hidden, encoder_outputs, source_mask) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((attended_input, embedded_input), -1) else: # shape: (group_size, target_embedding_dim) decoder_input = embedded_input # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_hidden) return output_projections, state def _prepare_attended_input( self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None, ) -> torch.Tensor: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length) encoder_outputs_mask = encoder_outputs_mask.float() # shape: (batch_size, max_input_sequence_length) input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input @staticmethod def _get_loss( logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor, ) -> torch.Tensor: # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) return all_metrics
class MaskedCopyNet(Model): def __init__(self, vocab: Vocabulary, embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, max_decoding_steps: int, attention: Attention = None, mask_embedder: TextFieldEmbedder = None, mask_attention: Attention = None, beam_size: int = None, target_namespace: str = "tokens", scheduled_sampling_ratio: float = 0., use_bleu: bool = True) -> None: super().__init__(vocab) self._target_namespace = target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index}) else: self._bleu = None # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._embedder = embedder self._mask_embedder = mask_embedder # Encodes the sequence of source embeddings into a sequence of hidden states. self._encoder = encoder num_classes = self.vocab.get_vocab_size(self._target_namespace) # Attention mechanism applied to the encoder output for each step. self._attention = attention self._mask_attention = mask_attention # Dense embedding of vocab words in the target space. target_embedding_dim = self._embedder.get_output_dim() # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. self._encoder_output_dim = self._encoder.get_output_dim() self._decoder_output_dim = self._encoder_output_dim if self._attention: # If using attention, a weighted average over encoder outputs will be concatenated # to the previous target embedding to form the input to the decoder at each # time step. self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim else: # Otherwise, the input to the decoder is just the previous target embedding. self._decoder_input_dim = target_embedding_dim if self._mask_attention: self._decoder_input_dim += self._mask_embedder.get_output_dim() # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. 
self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections(last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward(self, # type: ignore source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None, mask_tokens: Dict[str, torch.LongTensor] = None, **kwargs) -> Dict[str, torch.Tensor]: del kwargs assert mask_tokens is not None or self._mask_embedder is None, \ 'You must pass `mask_tokens` when `mask_embedder` is not None' state = self.encode(source_tokens, mask_tokens) if target_tokens: state = self.init_decoder_state(state) output_dict = self._forward_loop(state, target_tokens) else: output_dict = {} if not self.training: state = self.init_decoder_state(state) predictions = self.beam_search(state) output_dict.update(predictions) if target_tokens and self._bleu: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, target_tokens["tokens"]) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for i, indices in enumerate(predicted_indices): curr_predictions = [] for ind in indices: ind = list(ind) # Collect indices till the first end_symbol if self._end_index in ind: ind = ind[:ind.index(self._end_index)] predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace) for x in ind] curr_predictions.append(predicted_tokens) all_predicted_tokens.append(curr_predictions) output_dict["predicted_tokens"] = all_predicted_tokens # [batch_size, k, num_decoding_steps] return output_dict def encode(self, source_tokens: Dict[str, torch.Tensor], mask_tokens: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder(embedded_input, source_mask) state = { "source_mask": source_mask, "encoder_outputs": encoder_outputs } if mask_tokens is not None and self._mask_embedder is not None: embedded_input = self._mask_embedder(mask_tokens) masker_mask = util.get_text_field_mask(mask_tokens) state.update( { "mask_source_mask": masker_mask, "mask_encoder_outputs": embedded_input } ) return state def init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()) # Initialize the decoder hidden state with the final output of the encoder. 
# shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim) return state def _forward_loop(self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. # shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections(input_choices, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(class_probabilities, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = {"predictions": predictions} if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. 
target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) output_dict["loss"] = loss return output_dict def beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _prepare_output_projections(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._embedder({self._target_namespace: last_predictions}) if self._attention: # shape: (group_size, encoder_output_dim) attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((attended_input, embedded_input), -1) else: # shape: (group_size, target_embedding_dim) decoder_input = embedded_input if self._mask_attention and self._mask_embedder: mask_encoder_outputs = state["mask_encoder_outputs"] mask_source_mask = state["mask_source_mask"] mask_attended_input = self._prepare_mask_attended_input( decoder_hidden, mask_encoder_outputs, mask_source_mask ) decoder_input = torch.cat((decoder_input, mask_attended_input), -1) # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_hidden) return output_projections, state def _prepare_attended_input(self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor: encoder_outputs_mask = encoder_outputs_mask.float() input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input def _prepare_mask_attended_input(self, decoder_hidden_state: torch.LongTensor = None, mask_encoder_outputs: torch.LongTensor = None, mask_encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor: encoder_outputs_mask = mask_encoder_outputs_mask.float() input_weights = self._mask_attention(decoder_hidden_state, mask_encoder_outputs, encoder_outputs_mask) attended_input = util.weighted_sum(mask_encoder_outputs, input_weights) return attended_input @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: # shape: (batch_size, num_decoding_steps) 
relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) return all_metrics
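# --- Illustrative sketch (hypothetical dimensions): how MaskedCopyNet's decoder input
# grows when both source attention and mask attention are enabled. The decoder input is
# the previous target embedding, concatenated with the attended source summary and the
# attended mask summary, matching the `_decoder_input_dim` bookkeeping in __init__.
def _demo_masked_copy_decoder_input_dims():
    import torch
    batch_size = 2
    target_embedding_dim, encoder_output_dim, mask_embedding_dim = 6, 8, 4
    embedded_input = torch.randn(batch_size, target_embedding_dim)
    attended_source = torch.randn(batch_size, encoder_output_dim)
    attended_mask = torch.randn(batch_size, mask_embedding_dim)
    # Mirrors _prepare_output_projections: source attention first, then mask attention.
    decoder_input = torch.cat((attended_source, embedded_input), -1)
    decoder_input = torch.cat((decoder_input, attended_mask), -1)
    assert decoder_input.shape[-1] == (
        encoder_output_dim + target_embedding_dim + mask_embedding_dim)
    return decoder_input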
class MyTransformer(Model): def __init__( self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, transformer: Dict, max_decoding_steps: int, target_namespace: str, target_embedder: TextFieldEmbedder = None, use_bleu: bool = True, ) -> None: super().__init__(vocab) self._target_namespace = target_namespace self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) self._pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) if use_bleu: self._bleu = BLEU(exclude_indices={ self._pad_index, self._end_index, self._start_index }) else: self._bleu = None self._seq_acc = SequenceAccuracy() self._max_decoding_steps = max_decoding_steps self._source_embedder = source_embedder self._ndim = transformer["d_model"] self.pos_encoder = PositionalEncoding(self._ndim, transformer["dropout"]) num_classes = self.vocab.get_vocab_size(self._target_namespace) self._transformer = Transformer(**transformer) self._transformer.apply(inplace_relu) if target_embedder is None: self._target_embedder = self._source_embedder else: self._target_embedder = target_embedder self._output_projection_layer = Linear(self._ndim, num_classes) def _get_mask(self, meta_data): mask = torch.zeros(1, len(meta_data), self.vocab.get_vocab_size( self._target_namespace)).float() for bidx, md in enumerate(meta_data): for k, v in self.vocab._token_to_index[ self._target_namespace].items(): if 'position' in k and k not in md['avail_pos']: mask[:, bidx, v] = float('-inf') return mask def generate_square_subsequent_mask(self, sz): mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = mask.float().masked_fill(~mask, float('-inf')).masked_fill(mask, 0.0) return mask @overrides def forward( self, source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None, meta_data: Any = None, ) -> Dict[str, torch.Tensor]: src, src_key_padding_mask = self._encode(self._source_embedder, source_tokens) memory = self._transformer.encoder( src, src_key_padding_mask=src_key_padding_mask) if meta_data is not None: target_vocab_mask = self._get_mask(meta_data) target_vocab_mask = target_vocab_mask.to(memory.device) else: target_vocab_mask = None output_dict = {} targets = None if target_tokens: targets = target_tokens["tokens"][:, 1:] target_mask = (util.get_text_field_mask({"tokens": targets}) == 1) assert targets.size(1) <= self._max_decoding_steps if self.training and target_tokens: tgt, tgt_key_padding_mask = self._encode( self._target_embedder, {"tokens": target_tokens["tokens"][:, :-1]}) tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to( memory.device) output = self._transformer.decoder( tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=src_key_padding_mask) logits = self._output_projection_layer(output) if target_vocab_mask is not None: logits += target_vocab_mask class_probabilities = F.softmax(logits.detach(), dim=-1) _, predictions = torch.max(class_probabilities, -1) logits = logits.transpose(0, 1) loss = self._get_loss(logits, targets, target_mask) output_dict["loss"] = loss else: assert self.training is False output_dict["loss"] = torch.tensor(0.0).to(memory.device) if targets is not None: max_target_len = targets.size(1) else: max_target_len = None predictions, class_probabilities = self._decoder_step_by_step( memory, src_key_padding_mask, target_vocab_mask,
max_target_len=max_target_len) predictions = predictions.transpose(0, 1) output_dict["predictions"] = predictions output_dict["class_probabilities"] = class_probabilities.transpose( 0, 1) if target_tokens: with torch.no_grad(): best_predictions = output_dict["predictions"] if self._bleu: self._bleu(best_predictions, targets) batch_size = targets.size(0) max_sz = max(best_predictions.size(1), targets.size(1), target_mask.size(1)) best_predictions_ = torch.zeros(batch_size, max_sz).to(memory.device) best_predictions_[:, :best_predictions. size(1)] = best_predictions targets_ = torch.zeros(batch_size, max_sz).to(memory.device) targets_[:, :targets.size(1)] = targets target_mask_ = torch.zeros(batch_size, max_sz).to(memory.device) target_mask_[:, :target_mask.size(1)] = target_mask self._seq_acc(best_predictions_.unsqueeze(1), targets_, target_mask_) return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): # shape: (batch_size, num_decoding_steps) predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Greedy decoding may still hand us an extra leading dimension; # we just want the single best sequence. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode( self, embedder: TextFieldEmbedder, tokens: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: src = embedder(tokens) * math.sqrt(self._ndim) src = src.transpose(0, 1) src = self.pos_encoder(src) mask = util.get_text_field_mask(tokens) mask = (mask == 0) return src, mask def _decoder_step_by_step( self, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, target_vocab_mask: torch.Tensor = None, max_target_len: int = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_size = memory.size(1) if getattr(self, "target_limit_decode_steps", False) and max_target_len is not None: num_decoding_steps = min(self._max_decoding_steps, max_target_len) else: num_decoding_steps = self._max_decoding_steps last_predictions = memory.new_full( (batch_size, ), fill_value=self._start_index).long() step_predictions: List[torch.Tensor] = [] all_predicts = memory.new_full((batch_size, num_decoding_steps), fill_value=0).long() for timestep in range(num_decoding_steps): all_predicts[:, timestep] = last_predictions tgt, tgt_key_padding_mask = self._encode( self._target_embedder, {"tokens": all_predicts[:, :timestep + 1]}) tgt_mask = self.generate_square_subsequent_mask(timestep + 1).to( memory.device) output = self._transformer.decoder( tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask) output_projections = self._output_projection_layer(output) if target_vocab_mask is not None: output_projections += target_vocab_mask class_probabilities = F.softmax(output_projections, dim=-1) _, predicted_classes = torch.max(class_probabilities, -1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes[timestep, :] step_predictions.append(last_predictions) if ((last_predictions == self._end_index) | (last_predictions == self._pad_index)).all(): break # shape: (num_decoding_steps, batch_size) predictions = torch.stack(step_predictions) return predictions, class_probabilities @staticmethod def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor, target_mask: torch.FloatTensor) -> torch.Tensor: logits = logits.contiguous() # shape: (batch_size, num_decoding_steps) relevant_targets = targets.contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask.contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu: all_metrics.update(self._bleu.get_metric(reset=reset)) all_metrics['seq_acc'] = self._seq_acc.get_metric(reset=reset) return all_metrics def load_state_dict(self, state_dict, strict=True): new_state_dict = {} for k, v in state_dict.items(): if k.startswith('module.'): new_state_dict[k[len('module.'):]] = v else: new_state_dict[k] = v super(MyTransformer, self).load_state_dict(new_state_dict, strict)
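# --- Illustrative sketch: the causal mask produced by
# ``MyTransformer.generate_square_subsequent_mask`` for sz=4. Position i may only attend
# to positions <= i; blocked positions hold -inf, allowed positions 0.0.
def _demo_square_subsequent_mask():
    import torch
    sz = 4
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(~mask, float('-inf')).masked_fill(mask, 0.0)
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])
    return mask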
class Bart(Model): """ BART model from the paper "BART: Denosing Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here uses a language modeling head and thus can be used for text generation. """ def __init__( self, model_name: str, vocab: Vocabulary, indexer: PretrainedTransformerIndexer = None, max_decoding_steps: int = 140, beam_size: int = 4, encoder: Seq2SeqEncoder = None, ): """ # Parameters model_name : `str`, required Name of the pre-trained BART model to use. Available options can be found in `transformers.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`. vocab : `Vocabulary`, required Vocabulary containing source and target vocabularies. indexer : `PretrainedTransformerIndexer`, optional (default = `None`) Indexer to be used for converting decoded sequences of ids to to sequences of tokens. max_decoding_steps : `int`, optional (default = `128`) Number of decoding steps during beam search. beam_size : `int`, optional (default = `5`) Number of beams to use in beam search. The default is from the BART paper. encoder : `Seq2SeqEncoder`, optional (default = `None`) Encoder to used in BART. By default, the original BART encoder is used. """ super().__init__(vocab) self.bart = BartForConditionalGeneration.from_pretrained(model_name) self._indexer = indexer or PretrainedTransformerIndexer( model_name, namespace="tokens") self._start_id = self.bart.config.bos_token_id # CLS self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id self._end_id = self.bart.config.eos_token_id # SEP self._pad_id = self.bart.config.pad_token_id # PAD self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_id, max_steps=max_decoding_steps, beam_size=beam_size or 1) self._rouge = ROUGE( exclude_indices={self._start_id, self._pad_id, self._end_id}) self._bleu = BLEU( exclude_indices={self._start_id, self._pad_id, self._end_id}) # Replace bart encoder with given encoder. We need to extract the two embedding layers so that # we can use them in the encoder wrapper if encoder is not None: assert (encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size) self.bart.model.encoder = _BartEncoderWrapper( encoder, self.bart.model.encoder.embed_tokens, self.bart.model.encoder.embed_positions, ) @overrides def forward( self, source_tokens: TextFieldTensors, target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]: """ Performs the forward step of Bart. # Parameters source_tokens : `TextFieldTensors`, required The source tokens for the encoder. We assume they are stored under the `tokens` key. target_tokens : `TextFieldTensors`, optional (default = `None`) The target tokens for the decoder. We assume they are stored under the `tokens` key. If no target tokens are given, the source tokens are shifted to the right by 1. # Returns `Dict[str, torch.Tensor]` During training, this dictionary contains the `decoder_logits` of shape `(batch_size, max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions` of shape `(batch_size, max_decoding_steps)` and `log_probabilities` of shape `(batch_size,)`. """ inputs = source_tokens targets = target_tokens input_ids, input_mask = inputs["tokens"]["token_ids"], inputs[ "tokens"]["mask"] outputs = {} # If no targets are provided, then shift input to right by 1. Bart already does this internally # but it does not use them for loss calculation. 
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets["tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:
            decoder_logits = self.bart(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
            )[0]

            outputs["decoder_logits"] = decoder_logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks.
            outputs["loss"] = sequence_cross_entropy_with_logits(
                decoder_logits,
                target_ids[:, 1:].contiguous(),
                target_mask[:, 1:].contiguous(),
                label_smoothing=0.1,
                average="token",
            )
        else:
            # Use the decoder start id and the start-of-sentence id to start the decoder.
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id, self._start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "encoder_states": None,
            }
            beam_result = self._beam_search.search(
                initial_decoder_ids, initial_state, self.take_step
            )

            predictions = beam_result[0]
            max_pred_indices = (
                beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1])
            )
            predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (
                beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1)
            )

            self.make_output_human_readable(outputs)

        return outputs

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache):
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            for attention_name, attention_cache in layer_cache.items():
                for tensor_name, cache_value in attention_cache.items():
                    key = (layer_index, attention_name, tensor_name)
                    cache_dict[key] = cache_value
        return cache_dict

    @staticmethod
    def _dict_to_decoder_cache(cache_dict):
        decoder_cache = []
        for key, cache_value in cache_dict.items():
            # Split the key into the layer index and the nested dict keys.
            layer_idx, attention_name, tensor_name = key
            # Extend decoder_cache to fit layer_idx + 1 layers.
            decoder_cache = decoder_cache + [{} for _ in range(layer_idx + 1 - len(decoder_cache))]
            cache = decoder_cache[layer_idx]
            if attention_name not in cache:
                cache[attention_name] = {}
            assert tensor_name not in cache[attention_name]
            cache[attention_name][tensor_name] = cache_value
        return decoder_cache

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape
            `(group_size, target_vocab_size)` and an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Only the last predictions are needed for the decoder, but we need to pad the decoder ids
        # to not mess up the positional embeddings in the decoder.
        padding_size = 0
        if step > 0:
            padding_size = step + 1
            padding = torch.full(
                (last_predictions.shape[0], padding_size),
                self._pad_id,
                dtype=last_predictions.dtype,
                device=last_predictions.device,
            )
            last_predictions = torch.cat([padding, last_predictions], dim=-1)

        decoder_cache = None
        decoder_cache_dict = {
            k: (state[k].contiguous() if state[k] is not None else None)
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        log_probabilities = None
        for i in range(padding_size, last_predictions.shape[1]):
            encoder_outputs = (
                (state["encoder_states"],) if state["encoder_states"] is not None else None
            )
            outputs = self.bart(
                input_ids=state["input_ids"],
                attention_mask=state["input_mask"],
                encoder_outputs=encoder_outputs,
                decoder_input_ids=last_predictions[:, : i + 1],
                past_key_values=decoder_cache,
                generation_mode=True,
                use_cache=True,
            )

            decoder_log_probabilities = F.log_softmax(outputs[0][:, 0], dim=-1)

            if log_probabilities is None:
                log_probabilities = decoder_log_probabilities
            else:
                idx = last_predictions[:, i].view(-1, 1)
                log_probabilities = decoder_log_probabilities + log_probabilities.gather(
                    dim=-1, index=idx
                )

            decoder_cache = outputs[1]

            state["encoder_states"] = outputs[2]

        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        return log_probabilities, state

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor
            should have shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list
            of lists of tokens.
        """
        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            # Decode each sequence in the batch into tokens.
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()}, self.vocab
            )
        output_dict["predicted_tokens"] = predicted_tokens

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics
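# Illustrative usage sketch for the `Bart` wrapper above (not part of the
# original module). It assumes the `facebook/bart-base` weights are available
# and the pinned `transformers` version this wrapper was written against
# (`take_step` passes the version-specific `generation_mode` kwarg). The token
# ids below are made up for illustration.
def _sketch_bart_generation():
    from allennlp.data import Vocabulary

    model = Bart(model_name="facebook/bart-base", vocab=Vocabulary())
    model.eval()
    source_tokens = {
        "tokens": {
            # <s> ... </s>; for bart-base, bos_token_id=0 and eos_token_id=2.
            "token_ids": torch.tensor([[0, 9064, 16, 372, 2]]),
            "mask": torch.tensor([[True, True, True, True, True]]),
        }
    }
    with torch.no_grad():
        # No targets, so `forward` takes the inference path: beam search plus
        # `make_output_human_readable`.
        outputs = model(source_tokens)
    # shape: (batch_size, max_decoding_steps)
    return outputs["predictions"]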
class LatentAlignmentCTC(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        upsample: torch.nn.Module = None,
        net: Seq2SeqEncoder = None,
        target_namespace: str = "target_tokens",
        target_embedding_dim: int = None,
        use_bleu: bool = True,
    ) -> None:
        super(LatentAlignmentCTC, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)
        self._blank_index = self.vocab.get_token_index(SPECIAL_BLANK_TOKEN,
                                                       self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={self._pad_index, self._blank_index})
        else:
            self._bleu = None

        self._source_embedder = source_embedder
        source_embedding_dim = source_embedder.get_output_dim()
        self._upsample = upsample or LinearUpsample(source_embedding_dim, s=3)
        self._net = net or StackedSelfAttentionEncoder(
            input_dim=source_embedding_dim,
            hidden_dim=128,
            projection_dim=128,
            feedforward_hidden_dim=512,
            num_layers=4,
            num_attention_heads=4,
        )

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        target_embedding_dim = self._net.get_output_dim()
        self._output_projection = torch.nn.Linear(target_embedding_dim, num_classes)

    @overrides
    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # source_upsampled : shape : (batch_size, max_input_sequence_length, encoder_input_dim * self.s)
        # source_mask_upsampled : shape : (batch_size, max_input_sequence_length)
        source_upsampled, source_mask_upsampled = self._upsample(embedded_input, source_mask)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        net_output = self._net(source_upsampled, source_mask_upsampled)
        output_dict = {"source_mask_upsampled": source_mask_upsampled, "net_output": net_output}

        alignment_logits = self._output_projection(net_output)
        output_dict["alignment_logits"] = alignment_logits

        if target_tokens:
            # Compute loss.
            loss = self._get_loss(output_dict, target_tokens)
            output_dict["loss"] = loss

        if not self.training:
            alignments = alignment_logits.detach().cpu().argmax(2)
            predictions = self.beta_inverse(alignments)
            output_dict["predictions"] = predictions
            if target_tokens and self._bleu:
                self._bleu(output_dict["predictions"], target_tokens["tokens"])

        return output_dict

    # TODO: too cheap; needs parallel processing.
    def beta_inverse(self, a: torch.Tensor):
        """
        Collapse CTC alignments: drop blanks, merge repeated tokens, and pad
        each row back to the original length.

        a : size (batch, sequence)
        """
        max_length = a.size(1)
        outputs = []
        for sequence in a.tolist():
            output = []
            for token in sequence:
                if token == self._blank_index:
                    continue
                elif len(output) == 0:
                    output.append(token)
                    continue
                elif token == output[-1]:
                    continue
                else:
                    output.append(token)
            pad_list = [self._pad_index] * (max_length - len(output))
            outputs.append(output + pad_list)
        return torch.LongTensor(outputs)

    def _get_loss(self, output_dict: Dict[str, torch.Tensor],
                  target_tokens: Dict[str, torch.LongTensor]) -> torch.Tensor:
        targets = target_tokens["tokens"]
        target_mask = util.get_text_field_mask(target_tokens)
        # shape: (batch_size, input_length, target_size)
        alignment_logits = output_dict["alignment_logits"]
        # shape: (batch_size, input_length)
        source_mask_upsampled = output_dict["source_mask_upsampled"]
        return sequence_ctc_loss_with_logits(
            alignment_logits, source_mask_upsampled, targets, target_mask, self._blank_index
        )

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Remove padding.
            if self._pad_index in indices:
                indices = indices[: indices.index(self._pad_index)]
            # Look up tokens for the predicted indices.
            predicted_tokens = [
                self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        # Provide "predicted_tokens" for output.
        output_dict["predicted_tokens"] = all_predicted_tokens
        del output_dict["alignment_logits"], output_dict["source_mask_upsampled"], output_dict["net_output"]
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
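# A worked, standalone sketch of the collapse rule that `beta_inverse` above
# implements (names here are illustrative). Blanks are dropped and repeats are
# merged; note that, as written, a blank does NOT separate repeats, so
# [5, 0, 5] collapses to [5], which differs from standard CTC decoding.
import torch


def _sketch_ctc_collapse(alignments: torch.Tensor, blank: int, pad: int) -> torch.Tensor:
    max_length = alignments.size(1)
    rows = []
    for sequence in alignments.tolist():
        output = []
        for token in sequence:
            if token == blank:
                continue
            if output and token == output[-1]:
                continue
            output.append(token)
        rows.append(output + [pad] * (max_length - len(output)))
    return torch.LongTensor(rows)


# _sketch_ctc_collapse(torch.tensor([[5, 5, 0, 7, 7, 0, 5]]), blank=0, pad=0)
# -> tensor([[5, 7, 5, 0, 0, 0, 0]])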
class BleuTest(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

    @multi_device
    def test_get_valid_tokens_mask(self, device: str):
        tensor = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]], device=device)
        result = self.metric._get_valid_tokens_mask(tensor).long()
        check = torch.tensor([[1, 1, 1, 0], [0, 1, 1, 0]], device=device)
        assert_allclose(result, check)

    @multi_device
    def test_ngrams(self, device: str):
        tensor = torch.tensor([1, 2, 3, 1, 2, 0], device=device)

        # Unigrams.
        counts = Counter(self.metric._ngrams(tensor, 1))
        unigram_check = {(1,): 2, (2,): 2, (3,): 1}
        assert counts == unigram_check

        # Bigrams.
        counts = Counter(self.metric._ngrams(tensor, 2))
        bigram_check = {(1, 2): 2, (2, 3): 1, (3, 1): 1}
        assert counts == bigram_check

        # Trigrams.
        counts = Counter(self.metric._ngrams(tensor, 3))
        trigram_check = {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1}
        assert counts == trigram_check

        # ngram size too big, no ngrams produced.
        counts = Counter(self.metric._ngrams(tensor, 7))
        assert counts == {}

    @multi_device
    def test_bleu_computed_correctly(self, device: str):
        self.metric.reset()

        # shape: (batch_size, max_sequence_length)
        predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]], device=device)

        # shape: (batch_size, max_gold_sequence_length)
        gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]], device=device)

        self.metric(predictions, gold_targets)

        assert self.metric._prediction_lengths == 6
        assert self.metric._reference_lengths == 5

        # Number of unigrams in predicted sentences that match gold sentences
        # (but not more than maximum occurrence of gold unigram within batch).
        assert self.metric._precision_matches[1] == (
            0  # no matches in first sentence.
            + 1  # one clipped match in second sentence.
            + 2  # two clipped matches in third sentence.
        )

        # Total number of predicted unigrams.
        assert self.metric._precision_totals[1] == (1 + 2 + 3)

        # Number of bigrams in predicted sentences that match gold sentences
        # (but not more than maximum occurrence of gold bigram within batch).
        assert self.metric._precision_matches[2] == (0 + 0 + 1)

        # Total number of predicted bigrams.
        assert self.metric._precision_totals[2] == (0 + 1 + 2)

        # Brevity penalty should be 1.0.
        assert self.metric._get_brevity_penalty() == 1.0

        bleu = self.metric.get_metric(reset=True)["BLEU"]
        check = math.exp(0.5 * (math.log(3) - math.log(6)) + 0.5 * (math.log(1) - math.log(3)))
        assert_allclose(bleu, check)

    @multi_device
    def test_bleu_computed_with_zero_counts(self, device: str):
        self.metric.reset()
        assert self.metric.get_metric()["BLEU"] == 0
class SimpleSeq2Seq(Model):
    """
    This ``SimpleSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the
        same value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function : ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 extra_vocab=None) -> None:
        super(SimpleSeq2Seq, self).__init__(vocab)
        self.extra_vocab = extra_vocab
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={ pad_index, self._end_index, self._start_index }) else: self._bleu = None # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._source_embedder = source_embedder # Encodes the sequence of source embeddings into a sequence of hidden states. self._encoder = encoder num_classes = self.vocab.get_vocab_size(self._target_namespace) # Attention mechanism applied to the encoder output for each step. if attention: if attention_function: raise ConfigurationError( "You can only specify an attention module or an " "attention function, but not both.") self._attention = attention elif attention_function: self._attention = LegacyAttention(attention_function) else: self._attention = None # Dense embedding of vocab words in the target space. target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim( ) self._target_embedder = Embedding(num_classes, target_embedding_dim) # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. self._encoder_output_dim = self._encoder.get_output_dim() self._decoder_output_dim = self._encoder_output_dim if self._attention: # If using attention, a weighted average over encoder outputs will be concatenated # to the previous target embedding to form the input to the decoder at each # time step. self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim else: # Otherwise, the input to the decoder is just the previous target embedding. self._decoder_input_dim = target_embedding_dim # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. # TODO (pradeep): Do not hardcode decoder cell type. self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. 
        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        # Optionally mask to a restricted vocabulary:
        # class_log_probabilities = class_log_probabilities * state['vocabTensor']

        return class_log_probabilities, state

    def createRestrictedVocabMask(self,
                                  source_tokens: Dict[str, torch.LongTensor],
                                  target_tokens: Dict[str, torch.LongTensor] = None):
        """Build a 0/1 mask over the vocabulary that keeps only the tokens appearing in each
        source sequence (the `extra_vocab` entries are converted to indices below)."""
        vocabList = []
        extraVocabIdx = []
        # Convert the extra vocabulary tokens to vocabulary indices.
        for i in self.extra_vocab:
            extraVocabIdx.append(self.vocab.get_token_index(i))
        for j in range(source_tokens['tokens'].size()[0]):
            # allIthTokenIndex = set(extraVocabIdx).union(set(source_tokens['tokens'][j].tolist()))
            allIthTokenIndex = set(source_tokens['tokens'][j].tolist())
            x = [0 for _ in range(self.vocab.get_vocab_size())]
            for i in allIthTokenIndex:
                # Skip hard-coded excluded indices (0 is the padding index).
                if i not in [0, 7]:
                    x[i] = 1
            vocabList.append(x)
        vocabTensor = torch.FloatTensor(vocabList)
        return vocabTensor

    def toStringFromIdx(self, a):
        strA = ""
        for i in a:
            strA += self.vocab.get_token_from_index(i) + " "
        print(strA)

    def findVocabTensorTokens(self, vocabTensorInput):
        # Collect the vocabulary indices that are switched on in the mask and print their tokens.
        a = []
        for idx, flag in enumerate(vocabTensorInput[0].tolist()):
            if flag == 1:
                a.append(idx)
        self.toStringFromIdx(a)

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            # Optionally restrict decoding to tokens seen in the source:
            # vocabTensor = self.createRestrictedVocabMask(source_tokens, target_tokens)
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            # state['vocabTensor'] = vocabTensor
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol.
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]
            _, target_sequence_length = targets.size()
            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # Optionally mask logits to a restricted vocabulary:
            # output_projections = output_projections * vocabTensor

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)
            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections into the target space,
        which can then be used to get probabilities of each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(
                decoder_hidden, encoder_outputs, source_mask)
            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(
            decoder_hidden_state, encoder_outputs, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size,
        num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1)
        steps and computes cross entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison. Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) return all_metrics
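# A minimal sketch of the shift-by-one alignment described in `_get_loss`
# above: the logit at timestep i is scored against the target at timestep
# i + 1, so the <S> column of `targets` (and its mask) is dropped before
# computing the loss. Shapes and values here are illustrative.
def _sketch_loss_alignment():
    import torch
    from allennlp.nn import util

    batch_size, num_decoding_steps, num_classes = 2, 6, 10
    # Decoder outputs l1 .. l6.
    logits = torch.randn(batch_size, num_decoding_steps, num_classes)
    # Targets <S> w1 w2 w3 <E> <P> <P>: one step longer than the logits.
    targets = torch.randint(0, num_classes, (batch_size, num_decoding_steps + 1))
    target_mask = torch.ones(batch_size, num_decoding_steps + 1).float()
    return util.sequence_cross_entropy_with_logits(
        logits, targets[:, 1:].contiguous(), target_mask[:, 1:].contiguous()
    )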
class BleuTest(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

    def test_get_valid_tokens_mask(self):
        tensor = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]])
        result = self.metric._get_valid_tokens_mask(tensor)
        result = result.long().numpy()
        check = np.array([[1, 1, 1, 0], [0, 1, 1, 0]])
        np.testing.assert_array_equal(result, check)

    def test_ngrams(self):
        tensor = torch.tensor([1, 2, 3, 1, 2, 0])

        # Unigrams.
        counts = Counter(self.metric._ngrams(tensor, 1))
        unigram_check = {(1,): 2, (2,): 2, (3,): 1}
        assert counts == unigram_check

        # Bigrams.
        counts = Counter(self.metric._ngrams(tensor, 2))
        bigram_check = {(1, 2): 2, (2, 3): 1, (3, 1): 1}
        assert counts == bigram_check

        # Trigrams.
        counts = Counter(self.metric._ngrams(tensor, 3))
        trigram_check = {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1}
        assert counts == trigram_check

        # ngram size too big, no ngrams produced.
        counts = Counter(self.metric._ngrams(tensor, 7))
        assert counts == {}

    def test_bleu_computed_correctly(self):
        self.metric.reset()

        # shape: (batch_size, max_sequence_length)
        predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])

        # shape: (batch_size, max_gold_sequence_length)
        gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]])

        self.metric(predictions, gold_targets)

        assert self.metric._prediction_lengths == 6
        assert self.metric._reference_lengths == 5

        # Number of unigrams in predicted sentences that match gold sentences
        # (but not more than maximum occurrence of gold unigram within batch).
        assert self.metric._precision_matches[1] == (
            0 +  # no matches in first sentence.
            1 +  # one clipped match in second sentence.
            2    # two clipped matches in third sentence.
        )

        # Total number of predicted unigrams.
        assert self.metric._precision_totals[1] == (1 + 2 + 3)

        # Number of bigrams in predicted sentences that match gold sentences
        # (but not more than maximum occurrence of gold bigram within batch).
        assert self.metric._precision_matches[2] == (0 + 0 + 1)

        # Total number of predicted bigrams.
        assert self.metric._precision_totals[2] == (0 + 1 + 2)

        # Brevity penalty should be 1.0.
        assert self.metric._get_brevity_penalty() == 1.0

        bleu = self.metric.get_metric(reset=True)["BLEU"]
        check = math.exp(0.5 * (math.log(3) - math.log(6)) + 0.5 * (math.log(1) - math.log(3)))
        np.testing.assert_approx_equal(bleu, check)

    def test_bleu_computed_with_zero_counts(self):
        self.metric.reset()
        assert self.metric.get_metric()["BLEU"] == 0
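# The closed form verified by `test_bleu_computed_correctly` above, spelled
# out: BLEU = BP * exp(sum_n w_n * log p_n), with brevity penalty BP = 1 here
# (the predictions are longer than the references), unigram precision
# p1 = 3/6, bigram precision p2 = 1/3, and weights w1 = w2 = 0.5.
def _sketch_bleu_closed_form():
    import math

    bp = 1.0
    p1, p2 = 3 / 6, 1 / 3
    return bp * math.exp(0.5 * math.log(p1) + 0.5 * math.log(p2))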
class AssociativeSeq2SeqChainedAttention(Model): def __init__( self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, target_embedder: TextFieldEmbedder, source_encoder: Seq2SeqEncoder, target_encoder: Seq2SeqEncoder, #instance_reference_similarity_function:MatrixAttention, #target_instance_ref_similarity_function:MatrixAttention, max_decoding_steps: int, attention: Attention = None, s2s_attention: Attention = None, t2t_attention: Attention = None, beam_size: int = None, target_namespace: str = "tokens", scheduled_sampling_ratio: float = 0., use_bleu: bool = True) -> None: super(AssociativeSeq2SeqChainedAttention, self).__init__(vocab) self._target_namespace = target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={ pad_index, self._end_index, self._start_index }) else: self._bleu = None # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._source_embedder = source_embedder # Encodes the sequence of source embeddings into a sequence of hidden states. self._source_encoder = source_encoder self._target_encoder = target_encoder self._encoder_output_dim = self._target_encoder.get_output_dim() self._decoder_output_dim = self._encoder_output_dim target_embedding_dim = target_embedder.get_output_dim() if attention: self._attention = attention self._s2s_attention = s2s_attention self._t2t_attention = t2t_attention self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim else: raise NotImplementedError num_classes = self.vocab.get_vocab_size(self._target_namespace) self._target_embedder = target_embedder self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. 
Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward( self, # type: ignore ref_source_tokens: Dict[str, torch.LongTensor], instance_source_tokens: Dict[str, torch.LongTensor], ref_target_tokens: Dict[str, torch.LongTensor], instance_target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: state = self._encode(ref_target_tokens, ref_source_tokens, instance_source_tokens) if instance_target_tokens: state = self._init_decoder_state(state) # The `_forward_loop` decodes the input sequence and computes the loss during training # and validation. output_dict = self._forward_loop(state, instance_target_tokens) else: output_dict = {} if not self.training: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if instance_target_tokens and self._bleu: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, instance_target_tokens["tokens"]) return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. 
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol.
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self,
                ref_target_tokens: Dict[str, torch.Tensor],
                ref_source_tokens: Dict[str, torch.Tensor],
                instance_source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_ref_target = self._target_embedder(ref_target_tokens)
        ref_target_mask = util.get_text_field_mask(ref_target_tokens)
        encoded_ref_target = self._target_encoder(embedded_ref_target, ref_target_mask)

        embedded_ref_source = self._source_embedder(ref_source_tokens)
        ref_source_mask = util.get_text_field_mask(ref_source_tokens)
        encoded_ref_source = self._source_encoder(embedded_ref_source, ref_source_mask)

        embedded_instance_source = self._source_embedder(instance_source_tokens)
        instance_source_mask = util.get_text_field_mask(instance_source_tokens)
        encoded_instance_source = self._source_encoder(
            embedded_instance_source, instance_source_mask)

        # Store the *encoded* reference and instance sources so that the chained attention in
        # `_prepare_output_projections` attends over encoder states rather than raw embeddings.
        return {
            "source_mask": ref_target_mask,
            "encoder_outputs": encoded_ref_target,
            "ref_source_mask": ref_source_mask,
            "ref_source_encoded": encoded_ref_source,
            "instance_source_mask": instance_source_mask,
            "instance_source_encoded": encoded_instance_source,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            self._target_encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]
            _, target_sequence_length = targets.size()
            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)
            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections into the target space,
        which can then be used to get probabilities of each target token for the next step.

        Inputs are the same as for `take_step()`.
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder._token_embedders['tokens']( last_predictions) if self._attention: # shape: (group_size, encoder_output_dim) attended_input_source = self._prepare_attended_input( decoder_hidden, state['encoder_outputs'], state['source_mask'], self._t2t_attention) attended_input_ref_source = self._prepare_attended_input( attended_input_source, state['ref_source_encoded'], state['ref_source_mask'], self._attention) attended_input_instance_source = self._prepare_attended_input( attended_input_ref_source, state['instance_source_encoded'], state['instance_source_mask'], self._s2s_attention) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat( (attended_input_instance_source, embedded_input), -1) else: raise NotImplementedError # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_hidden) return output_projections, state def _prepare_attended_input(self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None, attention=None) -> torch.Tensor: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs_mask = encoder_outputs_mask.float() # shape: (batch_size, max_input_sequence_length) input_weights = attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. 
The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update(self._bleu.get_metric(reset=reset)) return all_metrics
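# A standalone sketch of the attention chaining used in
# `_prepare_output_projections` above: the decoder hidden state attends over
# the encoded reference target; that summary becomes the query over the
# encoded reference source; and that summary in turn queries the encoded
# instance source. Plain dot-product attention stands in for the configurable
# `Attention` modules, and all names and shapes are illustrative.
import torch
import torch.nn.functional as F


def _sketch_dot_attend(query: torch.Tensor, memory: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # query: (batch, dim), memory: (batch, time, dim), mask: (batch, time)
    scores = torch.bmm(memory, query.unsqueeze(-1)).squeeze(-1)
    scores = scores.masked_fill(~mask.bool(), float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights.unsqueeze(1), memory).squeeze(1)


def _sketch_chained_attention():
    batch, time, dim = 2, 7, 16
    decoder_hidden = torch.randn(batch, dim)
    ref_target = torch.randn(batch, time, dim)
    ref_source = torch.randn(batch, time, dim)
    instance_source = torch.randn(batch, time, dim)
    mask = torch.ones(batch, time)

    summary_t2t = _sketch_dot_attend(decoder_hidden, ref_target, mask)
    summary_ref_source = _sketch_dot_attend(summary_t2t, ref_source, mask)
    summary_instance = _sketch_dot_attend(summary_ref_source, instance_source, mask)
    # `summary_instance` is what gets concatenated with the target embedding
    # to form the decoder input.
    return summary_instance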