def generate(self,
             code: torch.Tensor,
             label: torch.Tensor,
             max_length: int,
             beam_size: int,
             lp_alpha: float) -> torch.Tensor:
    """
    Decode one sequence per batch element with beam search.

    Parameters
    ----------
    code : ``torch.Tensor``
        Conditioning tensor; its first dimension is taken as the batch size.
        It is stored in the search state under ``'code'`` and also supplies
        the dtype/device of the zero-initialised ``'h'``/``'c'`` entries and
        of ``'length_alpha'``.
    label : ``torch.Tensor``
        Stored in the search state under ``'label'``; also supplies the
        dtype/device of the start predictions and of the ``'length'`` counter.
    max_length : ``int``
        Passed to ``BeamSearch`` as ``max_steps``.
    beam_size : ``int``
        Passed to ``BeamSearch`` as ``beam_size``.
    lp_alpha : ``float``
        Broadcast to a per-example ``'length_alpha'`` state entry.
        NOTE(review): presumably a length-penalty exponent consumed by
        ``self.beam_search_step`` -- confirm against that method.

    Returns
    -------
    torch.Tensor
        The predictions returned by ``BeamSearch.search``; the accompanying
        log probabilities are discarded.
    """
    start_index = self.vocab.get_token_index('<s>')
    end_index = self.vocab.get_token_index('</s>')
    beam_search = BeamSearch(end_index=end_index,
                             max_steps=max_length,
                             beam_size=beam_size,
                             per_node_beam_size=3)

    batch_size = code.shape[0]

    # Build the start tokens directly with ``label``'s dtype and device rather
    # than allocating an uninitialised CPU float tensor and converting it.
    start_predictions = label.new_full((batch_size,), start_index)

    zero_state = code.new_zeros(batch_size,
                                self._generator._module.hidden_size)
    start_state = {
        # NOTE(review): 'h' and 'c' deliberately reference the same zero
        # tensor, as before. This is only safe as long as the step function
        # never mutates them in place -- confirm.
        'h': zero_state,
        'c': zero_state,
        'code': code,
        'label': label,
        'length': label.new_ones(batch_size),
        'length_alpha': code.new_full((batch_size,), lp_alpha),
    }

    all_predictions, _ = beam_search.search(
        start_predictions=start_predictions,
        start_state=start_state,
        step=self.beam_search_step)
    return all_predictions
def test_beam_search_matches_greedy(self):
    """Beam search with ``beam_size=1`` must decode the same tokens as greedy prediction."""
    model = self.trained_model
    xintent_state = model._states["xintent"]
    encoded = self.get_sample_encoded_output()

    # Reference result: the model's own greedy decoder.
    greedy_prediction = model.greedy_predict(
        final_encoder_output=encoded,
        target_embedder=xintent_state.embedder,
        decoder_cell=xintent_state.decoder_cell,
        output_projection_layer=xintent_state.output_projection_layer
    )
    greedy_tokens = model.decode_all(greedy_prediction)

    # Candidate result: a width-one beam search started from the start symbol.
    searcher = BeamSearch(model._end_index,
                          max_steps=model._max_decoding_steps,
                          beam_size=1)
    num_examples = encoded.size()[0]
    first_tokens = encoded.new_full((num_examples,),
                                    fill_value=model._start_index,
                                    dtype=torch.long)
    initial_state = {"decoder_hidden": encoded}
    top_k, _ = searcher.search(first_tokens, initial_state,
                               xintent_state.take_step)

    # Only the first batch element is decoded and compared.
    beam_tokens = model.decode_all(top_k[0])

    assert beam_tokens == greedy_tokens
def test_greedy_decode_matches_beam_search(self):
    """The model's ``_forward_loop`` and a width-one beam search must predict the same tokens."""
    model = self.model
    searcher = BeamSearch(model._end_index,
                          max_steps=model._max_decoding_steps,
                          beam_size=1)
    source_tokens = self.dataset.as_tensor_dict()["source_tokens"]

    def fresh_decoder_state():
        # Each decoding strategy starts from its own freshly encoded state.
        return model._init_decoder_state(model._encode(source_tokens))

    # Greedy predictions from the model's _forward_loop method.
    greedy_output = model.decode(model._forward_loop(fresh_decoder_state()))

    # Greedy predictions from beam search (beam size = 1).
    state = fresh_decoder_state()
    source_mask = state["source_mask"]
    num_examples = source_mask.size()[0]
    first_tokens = source_mask.new_full((num_examples,),
                                        fill_value=model._start_index)
    top_k_predictions, _ = searcher.search(first_tokens, state,
                                           model.take_step)
    beam_output = model.decode({"predictions": top_k_predictions})

    # Both strategies should yield identical token sequences.
    assert greedy_output["predicted_tokens"] == beam_output["predicted_tokens"]
def _check_results(
    self,
    batch_size: int = 5,
    expected_top_k: np.array = None,
    expected_log_probs: np.array = None,
    beam_search: BeamSearch = None,
    state: Dict[str, torch.Tensor] = None,
    take_step=take_step_with_timestep,
) -> None:
    """
    Run a beam search from all-zero start predictions and check its output.

    Any argument left as ``None`` falls back to the matching attribute on the
    test case (``expected_top_k``, ``expected_log_probs``, ``beam_search``);
    a missing or empty ``state`` becomes an empty dict. Shapes are checked
    for the whole batch, values only for the first batch element.
    """
    if expected_top_k is None:
        expected_top_k = self.expected_top_k
    if expected_log_probs is None:
        expected_log_probs = self.expected_log_probs
    if not state:
        state = {}
    if not beam_search:
        beam_search = self.beam_search

    width = beam_search.beam_size
    start_tokens = torch.tensor([0] * batch_size)
    top_k, log_probs = beam_search.search(start_tokens, state, take_step)  # type: ignore

    # top_k has shape `(batch_size, beam_size, max_predicted_length)`;
    # the last dimension is not pinned here.
    assert list(top_k.size())[:-1] == [batch_size, width]
    np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

    # log_probs has shape `(batch_size, beam_size)`.
    assert list(log_probs.size()) == [batch_size, width]
    np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs, rtol=1e-6)
def test_beam_search_matches_greedy(self):
    """Check that a beam of width one reproduces the model's greedy decoding."""
    model = self.trained_model
    decoder_state = model._states["xintent"]
    encoder_output = self.get_sample_encoded_output()

    # Greedy decoding straight from the model.
    greedy_tokens = model.decode_all(
        model.greedy_predict(
            final_encoder_output=encoder_output,
            target_embedder=decoder_state.embedder,
            decoder_cell=decoder_state.decoder_cell,
            output_projection_layer=decoder_state.output_projection_layer,
        )
    )

    # Beam search restricted to a single hypothesis per example.
    single_beam = BeamSearch(model._end_index,
                             max_steps=model._max_decoding_steps,
                             beam_size=1)
    batch = encoder_output.size()[0]
    start_tokens = encoder_output.new_full((batch, ),
                                           fill_value=model._start_index,
                                           dtype=torch.long)
    search_result = single_beam.search(start_tokens,
                                       {"decoder_hidden": encoder_output},
                                       decoder_state.take_step)
    beam_predictions = search_result[0]

    # Decode only the first batch element, as the greedy side is compared against it.
    beam_tokens = model.decode_all(beam_predictions[0])

    assert beam_tokens == greedy_tokens
def test_greedy_decode_matches_beam_search(self):
    """Verify that ``_forward_loop`` predictions equal those of a width-one beam search."""
    # pylint: disable=protected-access
    model = self.model
    single_beam = BeamSearch(model._end_index,
                             max_steps=model._max_decoding_steps,
                             beam_size=1)
    tensors = self.dataset.as_tensor_dict()

    # Greedy predictions from the model's _forward_loop method.
    greedy_state = model._init_decoder_state(
        model._encode(tensors["source_tokens"]))
    greedy_result = model.decode(model._forward_loop(greedy_state))

    # Greedy predictions from beam search (beam size = 1), on a fresh state.
    beam_state = model._init_decoder_state(
        model._encode(tensors["source_tokens"]))
    mask = beam_state["source_mask"]
    start_tokens = mask.new_full((mask.size()[0],),
                                 fill_value=model._start_index)
    search_result = single_beam.search(start_tokens, beam_state,
                                       model.take_step)
    beam_result = model.decode({"predictions": search_result[0]})

    # The two decoding paths must agree token for token.
    assert greedy_result["predicted_tokens"] == beam_result["predicted_tokens"]
def test_empty_sequences(self):
    """
    Starting from ``end_index - 1`` the search is expected to warn about empty
    sequences and return only end tokens with zero log probability.
    """
    start_token = self.end_index - 1
    initial_predictions = torch.LongTensor([start_token, start_token])
    single_beam = BeamSearch(self.end_index, beam_size=1)

    with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
        result = single_beam.search(initial_predictions, {},
                                    take_step_with_timestep)
    predictions, log_probs = result

    # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
    assert list(predictions.size()) == [2, 1, 1]
    # log probs should have shape `(batch_size, beam_size)`.
    assert list(log_probs.size()) == [2, 1]

    assert (predictions == self.end_index).all()
    assert (log_probs == 0).all()
def _check_results(self,
                   batch_size: int = 5,
                   expected_top_k: np.array = None,
                   expected_log_probs: np.array = None,
                   beam_search: BeamSearch = None,
                   state: Dict[str, torch.Tensor] = None) -> None:
    """
    Run a beam search from all-zero start predictions and check its output.

    Arguments left as ``None`` fall back to the attributes of the same name
    on the test case; a missing or empty ``state`` becomes an empty dict.
    Shapes are verified for the full batch, values for the first element only.
    """
    if expected_top_k is None:
        expected_top_k = self.expected_top_k
    if expected_log_probs is None:
        expected_log_probs = self.expected_log_probs
    if not state:
        state = {}
    if not beam_search:
        beam_search = self.beam_search

    width = beam_search.beam_size
    start_tokens = torch.tensor([0] * batch_size)  # pylint: disable=not-callable
    top_k, log_probs = beam_search.search(start_tokens, state, take_step)  # type: ignore

    # top_k has shape `(batch_size, beam_size, max_predicted_length)`;
    # the final dimension is left unchecked.
    assert list(top_k.size())[:-1] == [batch_size, width]
    np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

    # log_probs has shape `(batch_size, beam_size)`.
    assert list(log_probs.size()) == [batch_size, width]
    np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)
class MASS(Model):
    """
    This ``MASS`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis
    for a neural machine translation system, an abstractive summarization system, or any other
    common seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences. The same embedder instance is also used for the
        target side.
    transformer_encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model. Its output dimension also determines the
        decoder embedding and output dimensions.
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, a beam of width 1 is used.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        If not given, the output dimension of ``source_embedder`` is used. The value is passed to
        the decoder as ``decoder_ffn_embed_dim``.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    use_fp16 : ``bool``, optional (default = False)
        Forwarded to the ``TransformerDecoder`` constructor.
    initializer : ``InitializerApplicator``, optional
        Applied to the whole model at the end of construction.
    regularizer : ``RegularizerApplicator``, optional (default = None)
        Forwarded to the base :class:`Model` constructor.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 transformer_encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 use_bleu: bool = True,
                 use_fp16: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        # The regularizer must be handed to the base class; otherwise a configured
        # regularizer is accepted but silently never applied.
        super().__init__(vocab, regularizer)
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._mask_index = self.vocab.get_token_index('[MASK]',
                                                      self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(
                self.vocab._padding_token,  # pylint: disable=protected-access
                self._target_namespace)
            # Special symbols carry no content and are excluded from BLEU.
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index,
                self._mask_index
            })
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = transformer_encoder

        # The target side reuses the source embedder rather than owning a separate embedding.
        target_embedding_dim = (target_embedding_dim
                                or source_embedder.get_output_dim())
        self._target_embedder = source_embedder

        # The decoder output dim is tied to the encoder output dim.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim
        self._decoder_input_dim = target_embedding_dim

        self._decoder = TransformerDecoder(
            use_fp16,
            self._target_embedder,
            decoder_layers=6,
            dropout=0.1,
            decoder_embed_dim=self._encoder_output_dim,
            decoder_ffn_embed_dim=target_embedding_dim,
            decoder_attention_heads=4,
            decoder_output_dim=self._decoder_output_dim,
            max_target_positions=512,
            attention_dropout=0.1,
        )

        initializer(self)

    def take_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs and
            the source mask. Each of these tensors has shape ``(group_size, *)``, where
            ``*`` can be any other number of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is the state dictionary returned by
            ``_prepare_output_projections``.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # The decoder consumes a text-field style dict, so wrap the last token of
        # every hypothesis as a length-one sequence: (group_size, 1).
        start_predictions = {'tokens': last_predictions.unsqueeze(1)}

        # NOTE(review): presumably (group_size, 1, num_classes), given the squeeze(1) below.
        output_projections, state = self._prepare_output_projections(
            start_predictions, state)

        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        # Drop the length-one time dimension: (group_size, num_classes).
        class_log_probabilities = class_log_probabilities.squeeze(1)

        return class_log_probabilities, state

    @overrides
    def forward(
            self,  # type: ignore
            encoder_tokens: Dict[str, torch.LongTensor],
            decoder_tokens: Dict[str, torch.LongTensor] = None,
            target_tokens: Dict[str, torch.LongTensor] = None,
            positions: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        encoder_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        decoder_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Decoder-side input tokens. When given, the decoder is run on them in a single pass
           (and the loss is computed if ``target_tokens`` is given too); when absent, beam
           search is used outside of training.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        positions : ``torch.LongTensor``, optional (default = None)
           Forwarded unchanged to the decoder.

        Returns
        -------
        Dict[str, torch.Tensor]
            May contain ``"predictions"``, ``"loss"`` and ``"class_log_probabilities"``
            depending on which inputs were supplied and on the training mode.
        """
        state = self._encode(encoder_tokens)

        if decoder_tokens:
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state,
                                             decoder_tokens=decoder_tokens,
                                             target_tokens=target_tokens,
                                             positions=positions)
        else:
            output_dict = {}

        if not self.training:
            if not decoder_tokens:
                predictions = self._forward_beam_search(state)
                output_dict.update(predictions)
                if target_tokens and self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens["tokens"])
            else:
                if target_tokens and self._bleu:
                    # Predictions from `_forward_loop` have no beam dimension to select from.
                    best_predictions = output_dict["predictions"]
                    self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source tokens; return the source mask and encoder outputs."""
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    @overrides
    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            decoder_tokens: Dict[str, torch.LongTensor] = None,
            target_tokens: Dict[str, torch.LongTensor] = None,
            positions: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Run the decoder once over ``decoder_tokens`` and take the arg-max class at each position.

        Returns a dict with ``"predictions"`` and, when ``target_tokens`` is given, ``"loss"``.

        Notes
        -----
        ``forward`` only calls this method when ``decoder_tokens`` is provided; the mask
        computation below operates on ``decoder_tokens`` directly.
        """
        decoder_token_mask = util.get_text_field_mask(decoder_tokens)
        # The decoder takes a padding mask (True where padded), the inverse of the token mask.
        decoder_padding_mask = (decoder_token_mask == 0)

        logits, state = self._prepare_output_projections(
            decoder_tokens,
            state,
            decoder_padding_mask=decoder_padding_mask,
            positions=positions)

        class_probabilities = F.softmax(logits, dim=-1)

        # Arg-max over the class (last) dimension.
        _, predictions = torch.max(class_probabilities, -1)

        output_dict = {"predictions": predictions}

        if target_tokens is not None:
            # Compute loss.
            target_mask = decoder_token_mask
            targets = target_tokens["tokens"]
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size(0)
        # Decoding is seeded with the '[MASK]' token index, not the start symbol.
        start_tokens = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._mask_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_tokens, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: Dict[str, torch.LongTensor],
                                    state: Dict[str, torch.Tensor],
                                    decoder_padding_mask: torch.LongTensor = None,
                                    positions: torch.LongTensor = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Run the decoder on the given decoder-side tokens and the encoder state to produce
        projections into the target space, which can then be used to get probabilities of
        each target token.

        ``last_predictions`` is a text-field style dict of token tensors; ``state`` is returned
        unchanged alongside the decoder output.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_out = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length); True where the source is padded.
        encoder_padding_mask = (state["source_mask"] == 0)

        prev_output_tokens = last_predictions

        decoder_output = self._decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            encoder_padding_mask=encoder_padding_mask,
            decoder_padding_mask=decoder_padding_mask,
            positions=positions)

        return decoder_output, state

    @staticmethod
    def _get_loss(logits: torch.Tensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute the masked cross entropy loss.

        ``logits`` (unnormalized outputs from the decoder), ``targets`` and ``target_mask`` are
        passed to ``util.sequence_cross_entropy_with_logits`` exactly as given. No shifting or
        trimming is performed here, so the caller must supply targets and a mask that are
        already aligned position-by-position with ``logits``.
        """
        return util.sequence_cross_entropy_with_logits(logits, targets,
                                                       target_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the BLEU metric when it is enabled and the model is not in training mode."""
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
class CustomAutoRegressiveSeqDecoder(SeqDecoder): def __init__( self, vocab: Vocabulary, decoder_net: DecoderNet, max_decoding_steps: int, target_embedder: Embedding, target_namespace: str = "tokens", tie_output_embedding: bool = False, scheduled_sampling_ratio: float = 0, label_smoothing_ratio: Optional[float] = None, beam_size: int = 4, tensor_based_metric: Metric = None, token_based_metric: Metric = None, ) -> None: super().__init__(target_embedder) self._vocab = vocab # Decodes the sequence of encoded hidden states into e new sequence of hidden states. self._decoder_net = decoder_net self._max_decoding_steps = max_decoding_steps self._target_namespace = target_namespace self._label_smoothing_ratio = label_smoothing_ratio # At prediction time, we use a beam search to find the most likely sequence of target tokens. # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self._vocab.get_token_index( START_SYMBOL, self._target_namespace) self._end_index = self._vocab.get_token_index(END_SYMBOL, self._target_namespace) self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) target_vocab_size = self._vocab.get_vocab_size(self._target_namespace) if self.target_embedder.get_output_dim( ) != self._decoder_net.target_embedding_dim: raise ConfigurationError( "Target Embedder output_dim doesn't match decoder module's input." ) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. 
self._output_projection_layer = Linear( self._decoder_net.get_output_dim(), target_vocab_size) if tie_output_embedding: if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape: raise ConfigurationError( "Can't tie embeddings with output linear layer, due to shape mismatch" ) self._output_projection_layer.weight = self.target_embedder.weight # These metrics will be updated during training and validation self._tensor_based_metric = tensor_based_metric self._token_based_metric = token_based_metric self._scheduled_sampling_ratio = scheduled_sampling_ratio def _forward_beam_search( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size, ), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _forward_loss( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] # Prepare embeddings for targets. 
They will be used as gold embeddings during decoder training # shape: (batch_size, max_target_sequence_length, embedding_dim) target_embedding = self.target_embedder(targets) # shape: (batch_size, max_target_batch_sequence_length) target_mask = util.get_text_field_mask(target_tokens) if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel: _, decoder_output = self._decoder_net( previous_state=state, previous_steps_predictions=target_embedding[:, :-1, :], encoder_outputs=encoder_outputs, source_mask=source_mask, previous_steps_mask=target_mask[:, :-1]) # shape: (group_size, max_target_sequence_length, num_classes) logits = self._output_projection_layer(decoder_output) else: batch_size = source_mask.size()[0] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full( (batch_size, ), fill_value=self._start_index) # shape: (steps, batch_size, target_embedding_dim) steps_embeddings = torch.Tensor([]) step_logits: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. 
# shape: (batch_size, steps, target_embedding_dim) state['previous_steps_predictions'] = steps_embeddings # shape: (batch_size, ) effective_last_prediction = last_predictions else: # shape: (batch_size, ) effective_last_prediction = targets[:, timestep] if timestep == 0: state['previous_steps_predictions'] = torch.Tensor([]) else: # shape: (batch_size, steps, target_embedding_dim) state[ 'previous_steps_predictions'] = target_embedding[:, : timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( effective_last_prediction, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(output_projections, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes # shape: (batch_size, 1, target_embedding_dim) last_predictions_embeddings = self.target_embedder( last_predictions).unsqueeze(1) # This step is required, since we want to keep up two different prediction history: gold and real if steps_embeddings.shape[-1] == 0: # pylint: disable=unsubscriptable-object # There is no previous steps, except for start vectors in ``last_predictions`` # shape: (group_size, 1, target_embedding_dim) steps_embeddings = last_predictions_embeddings else: # shape: (group_size, steps_count, target_embedding_dim) steps_embeddings = torch.cat( [steps_embeddings, last_predictions_embeddings], 1) # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. 
target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) # TODO: We will be using beam search to get predictions for validation, but if beam size in 1 # we could consider taking the last_predictions here and building step_predictions # and use that instead of running beam search again, if performance in validation is taking a hit output_dict = {'loss': loss} return output_dict def _prepare_output_projections( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Decode current state and last prediction to produce produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`. """ # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, steps_count, decoder_output_dim) previous_steps_predictions = state.get("previous_steps_predictions") # shape: (batch_size, 1, target_embedding_dim) last_predictions_embeddings = self.target_embedder( last_predictions).unsqueeze(1) if previous_steps_predictions is None or previous_steps_predictions.shape[ -1] == 0: # There is no previous steps, except for start vectors in ``last_predictions`` # shape: (group_size, 1, target_embedding_dim) previous_steps_predictions = last_predictions_embeddings else: # shape: (group_size, steps_count, target_embedding_dim) previous_steps_predictions = torch.cat( [previous_steps_predictions, last_predictions_embeddings], 1) decoder_state, decoder_output = self._decoder_net( previous_state=state, encoder_outputs=encoder_outputs, source_mask=source_mask, previous_steps_predictions=previous_steps_predictions) state["previous_steps_predictions"] = previous_steps_predictions # Update state with new decoder state, override 
previous state state.update(decoder_state) if self._decoder_net.decodes_parallel: decoder_output = decoder_output[:, -1, :] # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_output) return output_projections, state def _get_loss(self, logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. 
The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits( logits, relevant_targets, relevant_mask, label_smoothing=self._label_smoothing_ratio) def get_output_dim(self): return self._decoder_net.get_output_dim() def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. 
Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if not self.training: if self._tensor_based_metric is not None: all_metrics.update( self._tensor_based_metric.get_metric( reset=reset)) # type: ignore if self._token_based_metric is not None: all_metrics.update( self._token_based_metric.get_metric( reset=reset)) # type: ignore return all_metrics @overrides def forward( self, encoder_out: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: state = encoder_out decoder_init_state = self._decoder_net.init_decoder_state(state) state.update(decoder_init_state) output_dict = self._forward_loss( state, target_tokens) if target_tokens else {} if not self.training: predictions = self._forward_beam_search(state) output_dict.update(predictions) if target_tokens: if self._tensor_based_metric is not None: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] # shape: (batch_size, target_sequence_length) self._tensor_based_metric( best_predictions, target_tokens["tokens"]) # type: ignore if self._token_based_metric is not None: output_dict = self.decode(output_dict) predicted_tokens = output_dict['predicted_tokens'] self._token_based_metric( predicted_tokens, # type: ignore [y.text for y in target_tokens["tokens"][1:-1]]) return output_dict @overrides def 
post_process( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self._vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict
class PhnMoChA(Model): """ This ``SimpleSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then uses the encoded representations to decode another sequence. You can use this as the basis for a neural machine translation system, an abstractive summarization system, or any other common seq2seq problem. The model here is simple, but should be a decent starting place for implementing recent models for these tasks. Parameters ---------- vocab : ``Vocabulary``, required Vocabulary containing source and target vocabularies. They may be under the same namespace (`tokens`) or the target tokens can have a different namespace, in which case it needs to be specified as `target_namespace`. source_embedder : ``TextFieldEmbedder``, required Embedder for source side sequences encoder : ``Seq2SeqEncoder``, required The encoder of the "encoder/decoder" model max_decoding_steps : ``int`` Maximum length of decoded sequences. target_namespace : ``str``, optional (default = 'tokens') If the target side vocabulary is different from the source side's, you need to specify the target's namespace here. If not, we'll assume it is "tokens", which is also the default choice for the source side, and this might cause them to share vocabularies. target_embedding_dim : ``int``, optional (default = source_embedding_dim) You can specify an embedding dimensionality for the target side. If not, we'll use the same value as the source embedder's. attention : ``Attention``, optional (default = None) If you want to use attention to get a dynamic summary of the encoder outputs at each step of decoding, this is the function used to compute similarity between the decoder hidden state and encoder outputs. attention_function: ``SimilarityFunction``, optional (default = None) This is if you want to use the legacy implementation of attention. This will be deprecated since it consumes more memory than the specialized attention modules. 
beam_size : ``int``, optional (default = None) Width of the beam for beam search. If not specified, greedy decoding is used. scheduled_sampling_ratio : ``float``, optional (default = 0.) At each timestep during training, we sample a random number between 0 and 1, and if it is not less than this value, we use the ground truth labels for the whole batch. Else, we use the predictions from the previous time step for the whole batch. If this value is 0.0 (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not using target side ground truth labels. See the following paper for more information: `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al., 2015 <https://arxiv.org/abs/1506.03099>`_. use_bleu : ``bool``, optional (default = True) If True, the BLEU metric will be calculated during validation. """ def __init__( self, vocab: Vocabulary, encoder: Seq2SeqEncoder, input_size: int, target_embedding_dim: int, decoder_hidden_dim: int, max_decoding_steps: int, max_decoding_ratio: float = 1.5, dep_parser: Model = None, pos_tagger: Model = None, cmvn: str = 'none', delta: int = 0, time_mask_width: int = 0, freq_mask_width: int = 0, time_mask_max_ratio: float = 0.0, dec_layers: int = 1, layerwise_pretraining: List[Tuple[int, int]] = None, cnn: Seq2SeqEncoder = None, conv_lstm: Seq2SeqEncoder = None, train_at_phn_level: bool = False, rnnt_layer: Model = None, phn_ctc_layer: Model = None, ctc_layer: Model = None, projection_layer: nn.Module = None, tie_proj: bool = False, att_ratio: float = 0.0, dep_ratio: float = 0.0, pos_ratio: float = 0.0, attention: Attention = None, attention_function: SimilarityFunction = None, latency_penalty: float = 0.0, loss_type: str = "nll", beam_size: int = 1, target_namespace: str = "tokens", phoneme_target_namespace: str = "phonemes", dropout: float = 0.0, blank: str = "_", sampling_strategy: str = "max", from_candidates: bool = False, scheduled_sampling_ratio: float = 0., 
initializer: InitializerApplicator = InitializerApplicator() ) -> None: super(PhnMoChA, self).__init__(vocab) self._input_size = input_size self._target_namespace = target_namespace self._phn_target_namespace = phoneme_target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio self._sampling_strategy = sampling_strategy self._train_at_phn_level = train_at_phn_level self._blank = blank self._dep_parser = dep_parser self._pos_tagger = pos_tagger self._ctc_layer = ctc_layer self._rnnt_layer = rnnt_layer self._phn_ctc_layer = phn_ctc_layer self._projection_layer = projection_layer if tie_proj: self._rnnt_layer.set_projection_layer(projection_layer) self._att_ratio = att_ratio self._dep_ratio = dep_ratio self._pos_ratio = pos_ratio self._loss_type = loss_type # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) self._pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._phn_pad_index = self.vocab.get_token_index( self.vocab._padding_token, self._phn_target_namespace) # pylint: disable=protected-access exclude_indices = {self._pad_index, self._end_index, self._start_index} self._logs: Dict[str, Union[Metric, None]] = { "att_wer": (WER(exclude_indices=exclude_indices) if self._att_ratio > 0 else None), "att_bleu": (BLEU(exclude_indices=exclude_indices) if self._att_ratio > 0 else None), "att_loss": (Average() if self._att_ratio > 0 else None), "phn_ctc_loss": (Average() if self._phn_ctc_layer else None), "ctc_loss": (Average() if self._ctc_layer else None), "rnnt_loss": (Average() if self._rnnt_layer else None), "dal_loss": (Average() if latency_penalty > 0.0 else None), "dep_loss": (Average() if self._dep_parser 
else None), "pos_loss": (Average() if self._pos_tagger else None), "tag_loss": (Average() if self._dep_parser else None), "arc_loss": (Average() if self._dep_parser else None) } # At prediction time, we use a beam search to find the most likely sequence of target tokens. self._max_decoding_steps = max_decoding_steps self._max_decoding_ratio = max_decoding_ratio self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Encodes the sequence of source embeddings into a sequence of hidden states. self._encoder = encoder self._cnn = cnn self._conv_lstm = conv_lstm num_classes = self.vocab.get_vocab_size(self._target_namespace) self._num_classes = num_classes # Attention mechanism applied to the encoder output for each step. if attention: if attention_function: raise ConfigurationError( "You can only specify an attention module or an " "attention function, but not both.") self._attention = attention elif attention_function: self._attention = LegacyAttention(attention_function) else: self._attention = None # Dense embedding of vocab words in the target space. self._target_embedder = Embedding(num_classes, target_embedding_dim) # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. self._encoder_output_dim = self._encoder.get_output_dim() self._decoder_output_dim = decoder_hidden_dim self._dec_layers = dec_layers if self._decoder_output_dim != self._encoder_output_dim: self.bridge = nn.Linear(self._encoder_output_dim, self._dec_layers * self._decoder_output_dim, bias=False) if self._attention: # If using attention, a weighted average over encoder outputs will be concatenated # to the previous target embedding to form the input to the decoder at each # time step. 
self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim self.att_out = Linear(self._decoder_output_dim + self._encoder_output_dim, self._decoder_output_dim, bias=True) else: # Otherwise, the input to the decoder is just the previous target embedding. self._decoder_input_dim = target_embedding_dim # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. # TODO (pradeep): Do not hardcode decoder cell type. self._decoder = nn.LSTM(self._decoder_input_dim, self._decoder_output_dim, num_layers=self._dec_layers, batch_first=True) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) self._input_norm = lambda x: x if cmvn == 'global': self._input_norm = nn.BatchNorm1d(self._input_size * (delta + 1)) elif cmvn == 'utt': self._input_norm = nn.InstanceNorm1d(self._input_size * (delta + 1)) self._delta = None if delta > 0: self._delta = Delta(order=delta) self._epoch_num = float("inf") self._layerwise_pretraining = layerwise_pretraining try: if isinstance(self._encoder, PytorchSeq2SeqWrapper): self._num_layers = self._encoder._module.num_layers else: self._num_layers = self._encoder.num_layers except AttributeError: self._num_layers = float("inf") self._output_layer_num = self._num_layers self._loss = None self._from_candidates = from_candidates if loss_type == "ocd": self._loss = OCDLoss(self._end_index, 1e-7, 1e-7, 5) elif loss_type == "edocd": self._loss = EDOCDLoss(self._end_index, 1e-7, 1e-7, 5) self._latency_penalty = latency_penalty self._target_granularity = self._target_namespace self.time_mask = TimeMask(time_mask_width, time_mask_max_ratio) self.freq_mask = FreqMask(freq_mask_width) initializer(self) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, 
Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. 
""" # shape: (group_size, num_classes) output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward( self, # type: ignore source_features: torch.FloatTensor, source_lengths: torch.LongTensor, target_tokens: Dict[str, torch.LongTensor] = None, words: Dict[str, torch.LongTensor] = None, segments: torch.LongTensor = None, pos_tags: torch.LongTensor = None, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, epoch_num: int = None, dataset: str = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Make foward pass with decoder logic for producing the entire target sequence. Parameters ---------- source_tokens : ``Dict[str, torch.LongTensor]`` The output of `TextField.as_array()` applied on the source `TextField`. This will be passed through a `TextFieldEmbedder` and then through an encoder. target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) Output of `Textfield.as_array()` applied on target `TextField`. We assume that the target tokens are also represented as a `TextField`. 
Returns ------- Dict[str, torch.Tensor] """ output_dict = {} if dataset is not None: self._target_granularity = dataset[0] if epoch_num is not None: self._epoch_num = epoch_num[0] self.set_output_layer_num() source_mask = util.get_mask_from_sequence_lengths( source_lengths, source_features.size(1)).bool() source_features = source_features.unsqueeze(1) # make a channel dim if self._delta: source_features = self._delta(source_features) batch_size, n_channels, timesteps, feature_size = source_features.size( ) source_features = self._input_norm( source_features.transpose(-2, -1).reshape(batch_size, -1, timesteps)) \ .view(batch_size, n_channels, feature_size, timesteps).transpose(-2, -1) source_features = self.time_mask(source_features, source_mask) source_features = self.freq_mask(source_features, source_mask) source_features = source_features.masked_fill( ~source_mask.unsqueeze(1).unsqueeze(-1).expand_as(source_features), 0.0) state = self._encode(source_features, source_lengths) source_lengths = util.get_lengths_from_binary_sequence_mask( state["source_mask"]) target_tokens["mask"] = (target_tokens[self._target_namespace] != self._pad_index).bool() if self._phn_ctc_layer and \ (self._phn_target_namespace in self._target_granularity or self._train_at_phn_level): raise NotImplementedError # logits = self._projection_layer(state["encoder_outputs"]) # phn_ctc_output_dict = self._phn_ctc_layer(logits, source_lengths, target_tokens) # output_dict.update({f"phn_ctc_{key}": value for key, value in phn_ctc_output_dict.items()}) if self._rnnt_layer is not None and self._rnnt_layer.loss_ratio > 0.0: rnnt_output_dict = self._rnnt_layer(state["encoder_outputs"], source_lengths, target_tokens) output_dict.update({ f"rnnt_{key}": value for key, value in rnnt_output_dict.items() }) if self._ctc_layer is not None and self._ctc_layer.loss_ratio > 0.0: logits = self._projection_layer(state["encoder_outputs"]) ctc_output_dict = self._ctc_layer(logits, source_lengths, target_tokens) 
output_dict.update({ f"ctc_{key}": value for key, value in ctc_output_dict.items() }) if target_tokens and self._att_ratio > 0.0 and \ self._target_granularity == self._target_namespace: targets = target_tokens[self._target_namespace] output_dict["target_tokens"] = targets target_mask = util.get_text_field_mask(target_tokens) if self._train_at_phn_level: raise NotImplementedError # state = self._get_phn_level_representations( # state["encoder_outputs"].detach().requires_grad_(True), # state["source_mask"], # output_dict["phn_ctc"]) state = self._init_decoder_state(state) output_dict.update(self._forward_loop(state, target_tokens)) self._logs["att_wer"](output_dict["predictions"], targets) if self._dep_parser or self._pos_tagger: relevant_mask = target_mask[:, 1:] attention_contexts, _ = _remove_eos( output_dict["attention_contexts"], relevant_mask) if segments is not None: segments, _ = remove_sentence_boundaries( segments, target_mask) attention_contexts, _ = \ char_to_word(attention_contexts, segments) contexts = {"tokens": attention_contexts} if self._dep_parser: parser_outputs = self._dep_parser(contexts, pos_tags, metadata, head_tags, head_indices) parser_outputs["dep_loss"] = parser_outputs.pop("loss") output_dict.update(parser_outputs) if self._pos_tagger: tagger_outputs = self._pos_tagger(contexts, pos_tags, metadata) tagger_outputs["pos_loss"] = tagger_outputs.pop("loss") output_dict.update(tagger_outputs) if not self.training: if self._target_granularity == self._target_namespace: if self._att_ratio > 0.0: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if target_tokens: targets = target_tokens[self._target_namespace] # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._logs["att_bleu"](best_predictions, targets) 
self._logs["att_wer"](best_predictions, targets) log_dict = self.decode(output_dict) verbose_target = [ self._indices_to_tokens(tokens.tolist()[1:]) for tokens in target_tokens[self._target_namespace] ] verbose_best_pred = [ beams[0] for beams in log_dict["predicted_tokens"] ] sep = " " if self._target_namespace == 'tokens' else "" with open(f"preds.{epoch_num[0]}.txt", "a+") as fp: fp.write("\n".join([ sep.join( map(lambda s: re.sub(self._blank, " ", s), words)) for words in verbose_best_pred ])) fp.write("\n") with open(f"golds.{epoch_num[0]}.txt", "a+") as fp: fp.write("\n".join([ sep.join( map(lambda s: re.sub(self._blank, " ", s), words)) for words in verbose_target ])) fp.write("\n") # for gold, pred in zip(verbose_target, verbose_best_pred): # print(gold, pred) if self.training: output_dict = self._collect_losses( output_dict, ctc=(self._ctc_layer.loss_ratio if self._ctc_layer else 0), rnnt=(self._rnnt_layer.loss_ratio if self._rnnt_layer else 0), att=self._att_ratio, dal=self._latency_penalty, dep=self._dep_ratio, pos=self._pos_ratio) if torch.isnan(output_dict["loss"]).any() or \ (torch.abs(output_dict["loss"]) == float('inf')).any(): for key, _ in output_dict.items(): if "loss" in key: output_dict[key] = output_dict[key].new_zeros( size=(), requires_grad=True).clone() self._update_metrics(output_dict) return output_dict def _indices_to_tokens(self, indices): # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index(x, namespace=self._target_namespace) for x in indices ] return predicted_tokens @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. 
This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ def _decode_predictions(input_key: str, output_key: str, beam=False): if input_key in output_dict: if beam: all_predicted_tokens = [ list(map(self._indices_to_tokens, beams)) for beams in sanitize(output_dict[input_key]) ] else: all_predicted_tokens = list( map(self._indices_to_tokens, sanitize(output_dict[input_key]))) output_dict[output_key] = all_predicted_tokens _decode_predictions("predictions", "predicted_tokens", beam=True) _decode_predictions("ctc_predictions", "ctc_predicted_tokens") _decode_predictions("rnnt_predictions", "rnnt_predicted_tokens") _decode_predictions("target_tokens", "targets") return output_dict def _encode(self, source_features: torch.FloatTensor, source_lengths: torch.LongTensor) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) if self._cnn is not None: source_features, source_lengths = self._cnn( source_features, source_lengths) source_mask = util.get_mask_from_sequence_lengths( source_lengths, source_features.size(1)) if self._conv_lstm is not None: source_features = self._conv_lstm(source_features, source_mask) if not isinstance(self._encoder, AWDRNN): encoder_outputs = self._encoder(source_features, source_mask) else: encoder_outputs, _, source_lengths = self._encoder( source_features, source_lengths, self._output_layer_num) source_mask = util.get_mask_from_sequence_lengths( source_lengths, encoder_outputs.size(1)) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) return {"source_mask": source_mask, "encoder_outputs": encoder_outputs} def _get_phn_level_representations( self, features: torch.FloatTensor, mask: torch.BoolTensor, phn_log_probs: torch.Tensor) -> Dict[str, torch.Tensor]: phn_enc_outs, segment_lengths = averaging_tensor_of_same_label( features, phn_log_probs, mask=mask) state 
= { "encoder_outputs": phn_enc_outs, "source_mask": util.get_mask_from_sequence_lengths(segment_lengths, int(max(segment_lengths))) } return state def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) encoder_outputs = state["encoder_outputs"] source_mask = state["source_mask"] final_encoder_output = util.get_final_encoder_states( encoder_outputs, source_mask, self._encoder.is_bidirectional()) if self._encoder_output_dim != self._dec_layers * self._decoder_output_dim: final_encoder_output = self.bridge(final_encoder_output) initial_decoder_input = final_encoder_output.view(-1, self._dec_layers, self._decoder_output_dim) \ .contiguous() # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = initial_decoder_input state["decoder_output"] = initial_decoder_input[:, 0] # shape: (batch_size, decoder_output_dim) state["decoder_context"] = encoder_outputs.new_zeros( batch_size, self._dec_layers, self._decoder_output_dim) state["attention"] = None if isinstance(self._attention, StatefulAttention): state["att_keys"], state["att_values"] = \ self._attention.init_state(encoder_outputs) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. 
""" # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] candidates = None if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens[self._target_namespace] _, target_sequence_length = targets.size() if self._loss is not None: candidates = target_to_candidates(targets, self._num_classes, ignore_indices=[ self._pad_index, self._start_index, self._end_index ]) # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. if isinstance(self._loss, EDOCDLoss): num_decoding_steps = int( target_sequence_length * self._max_decoding_ratio) - 1 else: num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] step_attns: List[torch.Tensor] = [] step_attn_cxts: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. 
# shape: (batch_size,) input_choices = last_predictions elif self._loss is not None: # shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( input_choices, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # list of tensors, shape: (batch_size, 1, num_encoding_steps) if self._attention: step_attns.append(state["attention"].unsqueeze(1)) step_attn_cxts.append(state["attention_contexts"].unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) predicted_classes = maybe_sample_from_candidates( class_probabilities, candidates=(candidates if self._from_candidates else None), strategy=self._sampling_strategy) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = { "predictions": predictions, } # shape: (batch_size, num_decoding_steps, num_encoding_steps) if self._attention: output_dict["attentions"] = torch.cat(step_attns, 1) output_dict["attention_contexts"] = torch.cat(step_attn_cxts, 1) if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. 
target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, predictions, targets, target_mask, candidates) output_dict["att_loss"] = loss if self._latency_penalty > 0.0: DAL = differentiable_average_lagging(output_dict["attentions"], source_mask, target_mask[:, 1:]) output_dict["dal"] = DAL return output_dict def _forward_beam_search( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Make forward pass during prediction using a beam search.""" batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size, ), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions } return output_dict def _prepare_output_projections(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long """ Decode current state and last prediction to produce produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`. 
""" # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, decoder_output_dim) decoder_output = state["decoder_output"] attention = state["attention"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((embedded_input, decoder_output), -1) # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) outputs, (decoder_hidden, decoder_context) = self._decoder( decoder_input.unsqueeze(1), (decoder_hidden.transpose(1, 0).contiguous(), decoder_context.transpose(1, 0).contiguous())) decoder_hidden = decoder_hidden.transpose(1, 0).contiguous() decoder_context = decoder_context.transpose(1, 0).contiguous() outputs = outputs.squeeze(1) if self._attention: # shape: (group_size, encoder_output_dim) attended_output, attention = self._prepare_attended_output( outputs, state) # shape: (group_size, decoder_output_dim) decoder_output = torch.tanh( self.att_out(torch.cat((attended_output, outputs), -1))) state["attention"] = attention state["attention_contexts"] = attended_output else: # shape: (group_size, target_embedding_dim) decoder_output = outputs state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context state["decoder_output"] = decoder_output # shape: (group_size, num_classes) output_projections = self._output_projection_layer(decoder_output) return output_projections, state def _prepare_attended_output( self, decoder_hidden_state: torch.Tensor, state: Dict[str, torch.Tensor]) -> torch.Tensor: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. 
# shape: (batch_size, max_input_sequence_length) encoder_outputs = state["encoder_outputs"] source_mask = state["source_mask"] prev_attention = state["attention"] att_keys = state["att_keys"] att_values = state["att_values"] # shape: (batch_size, max_input_sequence_length) mode = "soft" if self.training else "hard" if isinstance(self._attention, MonotonicAttention): encoder_outs: Dict[str, torch.Tensor] = { "value": state["encoder_outputs"], "mask": state["source_mask"] } monotonic_attention, chunk_attention = self._attention( encoder_outs, decoder_hidden_state, prev_attention, mode=mode) # shape: (batch_size, encoder_output_dim) attended_output = util.weighted_sum(encoder_outputs, chunk_attention) attention = monotonic_attention elif isinstance(self._attention, StatefulAttention): attended_output, attention = self._attention( decoder_hidden_state, att_keys, att_values, source_mask) else: attention = self._attention(decoder_hidden_state, source_mask) attended_output = util.weighted_sum(encoder_outputs, attention) return attended_output, attention # @staticmethod def _get_loss(self, logits: torch.FloatTensor, predictions: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor, candidates: torch.LongTensor = None) -> torch.Tensor: """ Compute loss. Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. 
That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() if self._loss is not None: if isinstance(self._loss, OCDLoss) or isinstance( self._loss, EDOCDLoss): self._loss.update_temperature(self._epoch_num) if isinstance(self._loss, EDOCDLoss): log_probs = F.log_softmax(logits, dim=-1) return self._loss(log_probs, predictions, relevant_targets, relevant_mask) else: raise NotImplementedError return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) def _collect_losses(self, output_dict: Dict[str, torch.Tensor], phn_ctc: float = 1.0, ctc: float = 1.0, rnnt: float = 1.0, att: float = 1.0, dep: float = 1.0, pos: float = 1.0, dal: float = 1.0) -> torch.Tensor: loss = 0.0 if "phn_ctc_loss" in output_dict: loss += phn_ctc * output_dict["phn_ctc_loss"] if "ctc_loss" in output_dict: loss += ctc * output_dict["ctc_loss"] if "rnnt_loss" in output_dict: loss += rnnt * output_dict["rnnt_loss"] if "att_loss" in output_dict: loss += att * output_dict["att_loss"] if "dep_loss" in output_dict: loss += dep * output_dict["dep_loss"] if "pos_loss" in output_dict: loss += pos * output_dict["pos_loss"] if "dal" in output_dict: loss += dal * output_dict["dal"] output_dict["loss"] = loss return output_dict def _update_metrics(self, output_dict: Dict[str, torch.Tensor]) -> torch.Tensor: for key, track_func in self._logs.items(): try: value = output_dict[key] value = value.item() if 
isinstance(value, torch.Tensor) else value track_func(value) except KeyError: continue @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} for key, metric_tracker in self._logs.items(): if "phn" in key and self._phn_target_namespace not in self._target_granularity: continue if "att" in key and self._target_namespace not in self._target_granularity: continue if metric_tracker is not None: metric_values = metric_tracker.get_metric(reset=reset) if isinstance(metric_values, dict): all_metrics.update(metric_values) else: all_metrics[key] = metric_values if self._ctc_layer: all_metrics.update({ f"ctc_{key}": value for key, value in self._ctc_layer.get_metrics( reset=reset).items() }) if self._rnnt_layer: all_metrics.update({ f"rnnt_{key}": value for key, value in self._rnnt_layer.get_metrics( reset=reset).items() }) if not self.training: if self._dep_parser: all_metrics.update(self._dep_parser.get_metrics(reset=reset)) if self._pos_tagger: all_metrics.update(self._pos_tagger.get_metrics(reset=reset)) return all_metrics def set_output_layer_num(self): output_layer_num = self._num_layers if self._layerwise_pretraining is not None: for epoch, layer_num in self._layerwise_pretraining: if self._epoch_num < epoch: break output_layer_num = layer_num self._output_layer_num = output_layer_num return output_layer_num
class Seq2SeqClaimRank(Model):
    """
    A ``Seq2SeqClaimRank`` model.

    This model is intended to be trained with a multi-instance learning objective that
    simultaneously tries to:
        - Decode the given post modifier (e.g. the ``target`` sequence).
        - Ensure that the model is attending to the proper claims during decoding (which are
          identified by the ``labels`` variable).
    The basic architecture is a seq2seq model with attention where the input sequence is the
    source sentence (without post-modifier), and the output sequence is the post-modifier. The
    main difference is that instead of performing attention over the input sequence, attention
    is performed over a collection of claims.

    Parameters
    ==========
    text_field_embedder : ``TextFieldEmbedder``
        Embeds words in the source sentence / claims. It is also used to embed the decoder
        inputs.
    sentence_encoder : ``Seq2VecEncoder``
        Encodes the entire source sentence into a single vector.
    claim_encoder : ``Seq2SeqEncoder``
        Encodes the words of each claim. It is wrapped in ``TimeDistributed`` so that it can be
        applied to every claim of an instance.
    attention : ``Attention``
        Type of attention mechanism used. WARNING: Do not normalize attention scores, and make
        sure to use a sigmoid activation. Otherwise the claim ranking loss will not work
        properly!
    max_steps : ``int``
        Maximum number of decoding steps. Default: 100 (same as ONMT).
    beam_size: ``int``
        Beam size used during evaluation. Default: 5 (same as ONMT).
    beta: ``float``
        Weight of attention loss term.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sentence_encoder: Seq2VecEncoder,
                 claim_encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_steps: int = 100,
                 beam_size: int = 5,
                 beta: float = 1.0) -> None:
        super(Seq2SeqClaimRank, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.sentence_encoder = sentence_encoder
        self.claim_encoder = TimeDistributed(claim_encoder)  # Handles additional sequence dim
        self.claim_encoder_dim = claim_encoder.get_output_dim()
        self.attention = attention
        self.decoder_embedding_dim = text_field_embedder.get_output_dim()
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.beta = beta

        # Since we are using the sentence encoding as the initial hidden state to the decoder,
        # the decoder hidden dim must match the sentence encoder hidden dim.
        self.decoder_output_dim = sentence_encoder.get_output_dim()
        self.decoder_0_cell = torch.nn.LSTMCell(
            self.decoder_embedding_dim + self.claim_encoder_dim,
            self.decoder_output_dim)
        self.decoder_1_cell = torch.nn.LSTMCell(self.decoder_output_dim,
                                                self.decoder_output_dim)

        # When projecting out we will use attention to combine claim embeddings into a single
        # context embedding, this will be concatenated with the decoder cell output before being
        # fed to the projection layer. Hence the expected input size is:
        #   decoder output dim + claim encoder output dim
        projection_input_dim = self.decoder_output_dim + self.claim_encoder_dim
        self.output_projection_layer = torch.nn.Linear(projection_input_dim,
                                                       vocab.get_vocab_size())

        self._start_index = self.vocab.get_token_index('<s>')
        self._end_index = self.vocab.get_token_index('</s>')
        self.beam_search = BeamSearch(self._end_index,
                                      max_steps=max_steps,
                                      beam_size=beam_size)

        pad_index = vocab.get_token_index(vocab._padding_token)
        self.bleu = BLEU(exclude_indices={pad_index, self._start_index, self._end_index})
        self.avg_reconstruction_loss = Average()
        self.avg_claim_scoring_loss = Average()

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information needed to
            predict the next step (the entries written by ``forward`` and
            ``_init_decoder_state``). Each of these tensors has shape ``(group_size, *)``,
            where ``*`` can be any other number of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is the state dictionary with the decoder entries updated.

        Notes
        -----
        We treat the inputs as a batch, even though ``group_size`` is not necessarily
        equal to ``batch_size``, since the group may contain multiple states
        for each source sentence in the batch.
        """
        # The attention logits are only needed for the training loss, so they are dropped here.
        output_projections, _, state = self._prepare_output_projections(last_predictions, state)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        return class_log_probabilities, state

    @overrides
    def forward(self,
                inputs: Dict[str, torch.LongTensor],
                claims: Dict[str, torch.LongTensor],
                targets: Dict[str, torch.LongTensor] = None,
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Forward pass of the model + decoder logic.

        Parameters
        ----------
        inputs : ``Dict[str, torch.LongTensor]``
            Output of `TextField.as_array()` from the `input` field.
        claims : ``Dict[str, torch.LongTensor]``
            Output of `ListField.as_array()` from the `claims` field.
        targets : ``Dict[str, torch.LongTensor]``
            Output of `TextField.as_array()` from the `target` field.
            Only expected during training and validation.
        labels : ``torch.Tensor``
            Output of `LabelField.as_array()` from the `labels` field, indicating which claims
            were used.
            Only expected during training and validation.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing loss tensor and decoder outputs.
        """
        # Obtain an encoding for each input sentence (e.g. the contexts)
        input_mask = util.get_text_field_mask(inputs)
        input_word_embeddings = self.text_field_embedder(inputs)
        input_encodings = self.sentence_encoder(input_word_embeddings, input_mask)

        # Next we encode claims. Note that here we have two additional sequence dimensions
        # (since there are multiple claims per instance, and we want to apply attention at the
        # word level). To deal with this we need to set `num_wrapping_dims=1` for the embedder,
        # and make the claim encoder TimeDistributed.
        claim_mask = util.get_text_field_mask(claims, num_wrapping_dims=1)
        claim_word_embeddings = self.text_field_embedder(claims, num_wrapping_dims=1)
        claim_encodings = self.claim_encoder(claim_word_embeddings, claim_mask)

        # Package the encoder outputs into a state dictionary.
        state = {
            'input_mask': input_mask,
            'input_encodings': input_encodings,
            'claim_mask': claim_mask,
            'claim_encodings': claim_encodings
        }

        # If ``target`` (the post-modifier) and ``labels`` (indicator of which claims are used)
        # are provided then we use them to compute loss.
        if (targets is not None) and (labels is not None):
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, targets, labels)
        else:
            output_dict = {}

        # If model is not training, then we perform beam search for decoding to obtain higher
        # quality outputs.
        if not self.training:
            # Re-initialise the decoder entries: the loss computation above overwrote them.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            # BLEU needs gold tokens, so it is only updated when targets were supplied
            # (plain prediction calls pass ``targets=None``).
            if targets is not None:
                top_k_predictions = output_dict['predictions']
                best_predictions = top_k_predictions[:, 0, :]
                self.bleu(best_predictions, targets['tokens'])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        Reads ``output_dict['predictions']``, keeps the first beam of each instance, cuts the
        sequence at the first end symbol and stores the resulting token strings under
        ``output_dict['predicted_tokens']``.
        """
        predicted_indices = output_dict['predictions']
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x) for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Adds fields to the state required to initialize the decoder."""
        batch_size = state['input_mask'].shape[0]

        # First decoder layer starts from zeros (trying to approximate the structure in
        # opennmt's graphic).
        state['decoder_0_h'] = state['input_encodings'].new_zeros(batch_size,
                                                                  self.decoder_output_dim)
        state['decoder_0_c'] = state['input_encodings'].new_zeros(batch_size,
                                                                  self.decoder_output_dim)

        # Initialize LSTM hidden state (e.g. h_0) with output of the sentence encoder.
        state['decoder_1_h'] = state['input_encodings']

        # Initialize LSTM context state (e.g. c_0) with zeros.
        state['decoder_1_c'] = state['input_encodings'].new_zeros(batch_size,
                                                                  self.decoder_output_dim)

        # Initialize previous context.
        state['prev_context'] = state['input_encodings'].new_zeros(batch_size,
                                                                   self.claim_encoder_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      targets: Dict[str, torch.Tensor],
                      labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the training losses by feeding the gold target tokens to the decoder.

        Returns a dictionary with the total loss (``loss``), its two components
        (``reconstruction_loss`` and ``claim_scoring_loss``) and the per-step claim attention
        logits (``attention_logits``).
        """
        target_tokens = targets['tokens']
        num_decoding_steps = target_tokens.shape[1] - 1

        # Teacher-forced decoding phase
        output_logit_list = []
        attention_logit_list = []
        for timestep in range(num_decoding_steps):
            # Feed target sequence as input
            decoder_input = target_tokens[:, timestep]
            output_logits, attention_logits, state = self._prepare_output_projections(
                decoder_input, state)
            # Store output and attention logits
            output_logit_list.append(output_logits.unsqueeze(1))
            attention_logit_list.append(attention_logits.unsqueeze(1))

        # Compute reconstruction loss
        output_logit_tensor = torch.cat(output_logit_list, dim=1)
        relevant_target_tokens = target_tokens[:, 1:].contiguous()
        target_mask = util.get_text_field_mask(targets)[:, 1:].contiguous()
        reconstruction_loss = util.sequence_cross_entropy_with_logits(output_logit_tensor,
                                                                      relevant_target_tokens,
                                                                      target_mask)

        # Compute claim scoring loss. A loss is computed between **each** attention vector and
        # the true label. In order for that to work we need to:
        #   a. Tile the source labels (so that they are copied for each word)
        #   b. Mask out padding tokens - this requires taking the outer-product of the target
        #      mask and the claim mask
        attention_logit_tensor = torch.cat(attention_logit_list, dim=1)
        claim_level_mask = (state['claim_mask'].sum(-1) > 0).long()
        attention_mask = target_mask.unsqueeze(-1) * claim_level_mask.unsqueeze(1)
        labels = labels.unsqueeze(1).repeat(1, num_decoding_steps, 1).float()
        claim_scoring_loss = F.binary_cross_entropy_with_logits(attention_logit_tensor,
                                                                labels,
                                                                reduction='none')
        claim_scoring_loss *= attention_mask.float()  # Apply mask

        # We want to apply 'batch' reduction (as is done in `sequence_cross_entropy...`) which
        # entails averaging over each dimension. Each pass collapses the last axis of the loss,
        # dividing by the number of unmasked entries along it, and then turns the denominator
        # into an "is non-empty" indicator for the next (outer) axis.
        denom = attention_mask
        for _ in range(3):
            denom = denom.sum(-1)
            claim_scoring_loss = claim_scoring_loss.sum(-1) / (denom.float() + 1e-13)
            denom = (denom > 0)

        total_loss = reconstruction_loss + self.beta * claim_scoring_loss

        # Update metrics
        self.avg_reconstruction_loss(reconstruction_loss)
        self.avg_claim_scoring_loss(claim_scoring_loss)

        output_dict = {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "claim_scoring_loss": claim_scoring_loss,
            "attention_logits": attention_logit_tensor
        }
        return output_dict

    def _prepare_output_projections(self,
                                    decoder_input: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Run one decoder step.

        Returns ``(output_logits, attention_logits, state)`` where ``output_logits`` are the
        vocabulary logits, ``attention_logits`` are the claim-level attention logits (word-level
        logits summed over each claim) and ``state`` is the input dictionary with the decoder
        cell states and ``prev_context`` updated in place.
        """
        # Embed decoder input
        decoder_word_embeddings = self.text_field_embedder({'tokens': decoder_input})

        # Concat with previous context
        concat = torch.cat((decoder_word_embeddings, state['prev_context']), dim=-1)

        # Run forward pass of decoder RNN
        decoder_0_h, decoder_0_c = self.decoder_0_cell(concat,
                                                       (state['decoder_0_h'],
                                                        state['decoder_0_c']))
        decoder_1_h, decoder_1_c = self.decoder_1_cell(decoder_0_h,
                                                       (state['decoder_1_h'],
                                                        state['decoder_1_c']))
        state['decoder_0_h'] = decoder_0_h
        state['decoder_0_c'] = decoder_0_c
        state['decoder_1_h'] = decoder_1_h
        state['decoder_1_c'] = decoder_1_c

        # Compute attention and get context embedding. We get an attention score for each word
        # in each claim. Then we sum up scores to get a claim level score (so we can use overlap
        # as supervision).
        claim_encodings = state['claim_encodings']
        claim_mask = state['claim_mask']

        batch_size, n_claims, claim_length, dim = claim_encodings.shape

        flattened_claim_encodings = claim_encodings.view(batch_size, -1, dim)
        flattened_claim_mask = claim_mask.view(batch_size, -1)
        flattened_attention_logits = self.attention(decoder_1_h,
                                                    flattened_claim_encodings,
                                                    flattened_claim_mask)
        attention_logits = flattened_attention_logits.view(batch_size, n_claims, claim_length)

        # Now get claim level encodings by summing word level attention.
        word_level_attention = util.masked_softmax(attention_logits, claim_mask)
        claim_encodings = util.weighted_sum(claim_encodings, word_level_attention)

        # If not training, get max attention word to replace unk.
        # NOTE(review): ``select_idx`` is computed but not returned or stored anywhere in this
        # class - confirm whether a caller is still expected to consume it.
        if not self.training:
            max_word = word_level_attention.argmax(dim=-1, keepdim=True)
            gathered = word_level_attention.gather(dim=-1, index=max_word)
            # Squeeze only the trailing singleton axis: a bare ``squeeze()`` would also drop the
            # batch / claim axis whenever it has size 1 and break the gather below.
            max_claim = gathered.squeeze(-1).argmax(dim=-1, keepdim=True)
            max_word = max_word.squeeze(-1).gather(dim=1, index=max_claim)
            select_idx = torch.cat((max_claim, max_word), dim=-1)
        else:
            select_idx = None

        # We compute our context directly from the claim word embeddings
        claim_mask = (claim_mask.sum(-1) > 0).float()
        attention_logits = attention_logits.sum(-1)
        attention_weights = torch.sigmoid(attention_logits) * claim_mask
        normalized_attention_weights = attention_weights / (attention_weights.sum(-1, True) + 1e-13)
        context_embedding = util.weighted_sum(claim_encodings, normalized_attention_weights)
        state['prev_context'] = context_embedding

        # Concatenate RNN output w/ context vector and feed through final hidden layer
        projection_input = torch.cat((decoder_1_h, context_embedding), dim=-1)
        output_logits = self.output_projection_layer(projection_input)

        return output_logits, attention_logits, state

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state['input_mask'].size()[0]
        start_predictions = state['input_mask'].new_full((batch_size,),
                                                         fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self.beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running average losses, plus BLEU when the model is not training."""
        # ``float`` accepts both a single-element tensor and a plain Python number, so this
        # does not depend on which of the two the ``Average`` metric hands back.
        all_metrics: Dict[str, float] = {
            'recon': float(self.avg_reconstruction_loss.get_metric(reset=reset)),
            'claim': float(self.avg_claim_scoring_loss.get_metric(reset=reset))
        }
        # Only update BLEU score during validation and evaluation
        if not self.training:
            all_metrics.update(self.bleu.get_metric(reset=reset))
        return all_metrics
class CustomAutoRegressiveSeqDecoder(SeqDecoder):
    """
    Auto-regressive sequence decoder built around a ``DecoderNet``.

    The decoder embeds the previously produced tokens with ``target_embedder``, runs them
    through ``decoder_net`` and projects the result onto the target vocabulary. With target
    tokens it computes a (optionally label-smoothed) cross entropy loss; outside training it
    additionally decodes with beam search and updates the configured metrics.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Vocabulary used to look up the start / end symbols and the target vocabulary size.
    decoder_net : ``DecoderNet``
        Module that turns previous step embeddings and encoder outputs into decoder outputs.
    max_decoding_steps : ``int``
        Maximum number of steps taken by the beam search.
    target_embedder : ``Embedding``
        Embedder for target tokens; its output dim must equal
        ``decoder_net.target_embedding_dim``.
    target_namespace : ``str``
        Vocabulary namespace of the target tokens.
    tie_output_embedding : ``bool``
        If true, the output projection shares its weight with ``target_embedder``.
    scheduled_sampling_ratio : ``float``
        During training, probability of feeding the model's own previous prediction instead of
        the gold token at a given step.
    label_smoothing_ratio : ``Optional[float]``
        Passed as ``label_smoothing`` to ``sequence_cross_entropy_with_logits``.
    beam_size : ``int``
        Width of the beam used by the beam search.
    tensor_based_metric : ``Metric``
        Metric called with the best predicted index tensor and the gold index tensor.
    token_based_metric : ``Metric``
        Metric called with predicted token strings.
    """
    def __init__(
            self,
            vocab: Vocabulary,
            decoder_net: DecoderNet,
            max_decoding_steps: int,
            target_embedder: Embedding,
            target_namespace: str = "tokens",
            tie_output_embedding: bool = False,
            scheduled_sampling_ratio: float = 0,
            label_smoothing_ratio: Optional[float] = None,
            beam_size: int = 4,
            tensor_based_metric: Metric = None,
            token_based_metric: Metric = None,
    ) -> None:
        super().__init__(target_embedder)

        self._vocab = vocab
        self._decoder_net = decoder_net
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio

        self._start_index = self._vocab.get_token_index(START_SYMBOL,
                                                        self._target_namespace)
        self._end_index = self._vocab.get_token_index(END_SYMBOL,
                                                      self._target_namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )

        self._output_projection_layer = Linear(
            self._decoder_net.get_output_dim(), target_vocab_size)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    "Can't tie embeddings with output linear layer, due to shape mismatch"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric

        self._scheduled_sampling_ratio = scheduled_sampling_ratio

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Decode with beam search starting from ``state``.

        Returns a dictionary with the top-k index sequences (``predictions``) and their log
        probabilities (``class_log_probabilities``).
        """
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _forward_loss(
            self, state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """Compute the loss against ``target_tokens``.

        When no scheduled sampling is configured and the decoder net decodes in parallel, all
        steps are computed in one call. Otherwise the decoder is unrolled step by step, and
        ``state`` is updated in place along the way.

        Returns a dictionary holding the loss under the key ``loss``.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = target_tokens["tokens"]

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder
        # training.
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_batch_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel:
            _, decoder_output = self._decoder_net(
                previous_state=state,
                previous_steps_predictions=target_embedding[:, :-1, :],
                encoder_outputs=encoder_outputs,
                source_mask=source_mask,
                previous_steps_mask=target_mask[:, :-1])

            # shape: (group_size, max_target_sequence_length, num_classes)
            logits = self._output_projection_layer(decoder_output)
        else:
            batch_size = source_mask.size()[0]
            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1

            # Initialize target predictions with the start index.
            # shape: (batch_size,)
            last_predictions = source_mask.new_full(
                (batch_size, ), fill_value=self._start_index)

            # Embeddings of the model's own predictions so far; empty until the first step.
            # Steps are concatenated along dim 1 below.
            steps_embeddings = torch.Tensor([])

            step_logits: List[torch.Tensor] = []

            for timestep in range(num_decoding_steps):
                if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                    # Scheduled sampling: continue from the model's own previous predictions.
                    # shape: (batch_size, steps, target_embedding_dim)
                    state['previous_steps_predictions'] = steps_embeddings

                    # shape: (batch_size, )
                    effective_last_prediction = last_predictions
                else:
                    # Use gold tokens at test time and at a rate of
                    # 1 - _scheduled_sampling_ratio during training.
                    # shape: (batch_size, )
                    effective_last_prediction = targets[:, timestep]

                    if timestep == 0:
                        state['previous_steps_predictions'] = torch.Tensor([])
                    else:
                        # shape: (batch_size, steps, target_embedding_dim)
                        state['previous_steps_predictions'] = target_embedding[:, :timestep]

                # shape: (batch_size, num_classes)
                output_projections, state = self._prepare_output_projections(
                    effective_last_prediction, state)

                # list of tensors, shape: (batch_size, 1, num_classes)
                step_logits.append(output_projections.unsqueeze(1))

                # shape (predicted_classes): (batch_size,)
                _, predicted_classes = torch.max(output_projections, 1)

                # shape (predicted_classes): (batch_size,)
                last_predictions = predicted_classes

                # shape: (batch_size, 1, target_embedding_dim)
                last_predictions_embeddings = self.target_embedder(
                    last_predictions).unsqueeze(1)

                # This step is required, since we want to keep up two different prediction
                # histories: gold and real.
                if steps_embeddings.shape[-1] == 0:
                    # There are no previous steps yet.
                    # shape: (group_size, 1, target_embedding_dim)
                    steps_embeddings = last_predictions_embeddings
                else:
                    # shape: (group_size, steps_count, target_embedding_dim)
                    steps_embeddings = torch.cat(
                        [steps_embeddings, last_predictions_embeddings], 1)

            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

        # Compute loss (``target_mask`` was already computed above).
        loss = self._get_loss(logits, targets, target_mask)

        # TODO: We will be using beam search to get predictions for validation, but if beam size
        # is 1 we could consider taking the last_predictions here and building step_predictions
        # and use that instead of running beam search again, if performance in validation is
        # taking a hit.
        output_dict = {'loss': loss}

        return output_dict

    def _prepare_output_projections(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`. ``state`` is updated in place (the
        ``previous_steps_predictions`` entry and whatever the decoder net returns as its state)
        and returned alongside the projections.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions")

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(
            last_predictions).unsqueeze(1)

        if previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0:
            # There are no previous steps, except for start vectors in ``last_predictions``
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1)

        decoder_state, decoder_output = self._decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions)
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)

        if self._decoder_net.decodes_parallel:
            # A parallel decoder returns outputs for every step; only the last one is needed.
            decoder_output = decoder_output[:, -1, :]

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _get_loss(self, logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """Cross entropy between ``logits`` and the targets shifted by one step.

        The first target position (the start symbol) is dropped from both ``targets`` and
        ``target_mask`` so that the logits of step ``i`` are compared with target ``i + 1``.
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(
            logits,
            relevant_targets,
            relevant_mask,
            label_smoothing=self._label_smoothing_ratio)

    def get_output_dim(self):
        """Return the output dimension of the wrapped decoder net."""
        return self._decoder_net.get_output_dim()

    def take_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Take one decoding step for the beam search.

        Returns the log probabilities over the target classes for the next step together with
        the updated state.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Collect the configured metrics; nothing is reported while training."""
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset))  # type: ignore
            if self._token_based_metric is not None:
                all_metrics.update(
                    self._token_based_metric.get_metric(reset=reset))  # type: ignore
        return all_metrics

    @overrides
    def forward(
            self,
            encoder_out: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Compute the loss (when targets are given) and, outside training, beam-search output.

        ``encoder_out`` is used as the decoding state and is extended in place with the initial
        decoder state. When ``target_tokens`` are given the returned dictionary contains
        ``loss``; when the module is not training it also contains the beam search results, and
        the configured metrics are updated if targets are available.
        """
        state = encoder_out
        decoder_init_state = self._decoder_net.init_decoder_state(state)
        state.update(decoder_init_state)

        if target_tokens:
            # ``_forward_loss`` rebinds entries of the state it is given. Outside training the
            # beam search below must start from the untouched initial state, so the loss pass
            # gets its own copy of the dictionary in that case.
            loss_state = state if self.training else dict(state)
            output_dict = self._forward_loss(loss_state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            if target_tokens:
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    # shape: (batch_size, target_sequence_length)

                    self._tensor_based_metric(
                        best_predictions, target_tokens["tokens"])  # type: ignore

                if self._token_based_metric is not None:
                    # ``post_process`` is the method of this class that fills in
                    # ``predicted_tokens``.
                    output_dict = self.post_process(output_dict)
                    predicted_tokens = output_dict['predicted_tokens']

                    # NOTE(review): this reads ``.text`` from the entries of
                    # ``target_tokens["tokens"]``, which is indexed as a tensor elsewhere in
                    # this class - confirm what the token based metric is meant to receive.
                    self._token_based_metric(
                        predicted_tokens,  # type: ignore
                        [y.text for y in target_tokens["tokens"][1:-1]])

        return output_dict

    @overrides
    def post_process(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert ``output_dict["predictions"]`` into token strings.

        Keeps the first beam of each instance, cuts it at the first end symbol and stores the
        tokens under ``output_dict["predicted_tokens"]``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self._vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict
class Editor(Model):
    """
    Action-based text editor that rewrites a draft conditioned on a set of triples.

    At every decoding step the model predicts one action index:

    * ``@@KEEP@@`` -- emit the draft token under the draft pointer and advance the pointer,
    * ``@@DROP@@`` -- advance the draft pointer without emitting anything,
    * an ordinary vocabulary index -- emit (generate) that token,
    * an index ``>= vocab_size`` -- copy the token of triple slot ``index - vocab_size``.

    The draft is encoded by a BiLSTM (``BUFFER``), the revised text is tracked by an LSTM cell
    (``STREAM``), and the final distribution mixes a generation softmax with a copy attention
    over the triples through a sigmoid gate (``G``).

    Parameters
    ----------
    vocab : ``Vocabulary``
        Vocabulary; must contain the start/end symbols and the ``@@COPY@@``, ``@@KEEP@@`` and
        ``@@DROP@@`` tokens in its default namespace.
    embed : ``TextFieldEmbedder``
        Embedder shared by triples, predicates, draft and actions. It must expose
        ``token_embedder_tokens.output_dim``.
    encoder_size : ``int``
        Size of the triple encoding and of the draft encoding (the BiLSTM uses
        ``encoder_size // 2`` units per direction).
    decoder_size : ``int``
        Hidden size of the ``STREAM`` LSTM cell.
    num_layers : ``int``
        Number of layers of the draft BiLSTM.
    beam_size : ``int``
        Beam width used at prediction time.
    max_decoding_steps : ``int``
        Maximum number of actions produced by beam search.
    use_bleu : ``bool``, optional (default = True)
        Whether to track BLEU when the model is not in training mode.
    initializer : ``InitializerApplicator``, optional
        Applied to the whole model once all sub-modules are built.
    """

    def __init__(
            self,
            vocab: Vocabulary,
            embed: TextFieldEmbedder,
            encoder_size: int,
            decoder_size: int,
            num_layers: int,
            beam_size: int,
            max_decoding_steps: int,
            use_bleu: bool = True,
            initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super().__init__(vocab)
        # Indices of the special tokens used throughout decoding.
        self.START, self.END = self.vocab.get_token_index(
            START_SYMBOL), self.vocab.get_token_index(END_SYMBOL)
        self.OOV = self.vocab.get_token_index(self.vocab._oov_token)  # pylint: disable=protected-access
        self.PAD = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
        self.COPY = self.vocab.get_token_index("@@COPY@@")
        self.KEEP = self.vocab.get_token_index("@@KEEP@@")
        self.DROP = self.vocab.get_token_index("@@DROP@@")
        # Actions in SYMBOL never count as an "add" step in `_decoder_step`
        # and are excluded from BLEU.
        self.SYMBOL = (self.START, self.END, self.PAD, self.KEEP, self.DROP)
        self.vocab_size = vocab.get_vocab_size()

        self.EMB = embed
        self.emb_size = self.EMB.token_embedder_tokens.output_dim
        self.encoder_size, self.decoder_size = encoder_size, decoder_size
        # Each triple is fed as the concatenation of three embeddings (see `_init_state`).
        self.FACT_ENCODER = FeedForward(3 * self.emb_size, 1, encoder_size,
                                        nn.Tanh())
        self.ATTN = AdditiveAttention(encoder_size + decoder_size,
                                      encoder_size)
        self.COPY_ATTN = AdditiveAttention(decoder_size, encoder_size)
        module = nn.LSTM(self.emb_size,
                         encoder_size // 2,
                         num_layers,
                         bidirectional=True,
                         batch_first=True)
        self.BUFFER = PytorchSeq2SeqWrapper(
            module)  # BiLSTM to encode draft text
        self.STREAM = nn.LSTMCell(2 * encoder_size,
                                  decoder_size)  # Store revised text
        self.BEAM = BeamSearch(self.END,
                               max_steps=max_decoding_steps,
                               beam_size=beam_size)
        self.U = nn.Sequential(nn.Linear(2 * encoder_size, decoder_size),
                               nn.Tanh())
        self.ADD = nn.Sequential(nn.Linear(self.emb_size, encoder_size),
                                 nn.Tanh())
        self.P = nn.Sequential(
            nn.Linear(encoder_size + decoder_size, decoder_size), nn.Tanh())
        self.W = nn.Linear(decoder_size, self.vocab_size)
        self.G = nn.Sequential(nn.Linear(decoder_size, 1), nn.Sigmoid())
        initializer(self)
        self._bleu = BLEU(
            exclude_indices=set(self.SYMBOL)) if use_bleu else None

    @overrides
    def forward(
            self,  # type: ignore
            metadata: List[Dict[str, Any]],
            triple_tokens: Dict[str, torch.LongTensor],
            triple_token_ids: torch.Tensor,
            predicate_tokens: Dict[str, torch.Tensor],
            draft_tokens: Dict[str, torch.LongTensor],
            action_tokens: Dict[str, torch.LongTensor] = None,
            revised_tokens: Dict[str, torch.LongTensor] = None,
            action_token_ids: torch.Tensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Compute the training loss and, outside training mode, beam-search predictions.

        Parameters
        ----------
        metadata : ``List[Dict[str, Any]]``
            Per-instance metadata; copied into the output under ``"metadata"``.
        triple_tokens, predicate_tokens, draft_tokens
            Text-field tensors consumed by ``_init_state``.
        triple_token_ids : ``torch.Tensor``
            Ids of the triple tokens, compared against ``action_token_ids`` to find copies.
        action_tokens : optional
            Gold action sequence. When given, ``"loss"`` is added to the output.
        revised_tokens : optional
            Gold revised text, only used for BLEU.
        action_token_ids : optional
            Ids of the gold actions, aligned with ``action_tokens``.
        **kwargs
            Extra batch fields are accepted and ignored.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``"metadata"`` always; ``"loss"`` when ``action_tokens`` is given;
            ``"predictions"`` and ``"predicted_log_probs"`` when not training.
        """
        state = self._init_state(triple_tokens, predicate_tokens, draft_tokens,
                                 triple_token_ids)
        if action_tokens:
            # Initialize Decoder
            state = self._decoder_init(state)
            # `_forward_loss` takes no extra keyword arguments, so the ignored
            # batch fields in `kwargs` must not be forwarded to it.
            output_dict = self._forward_loss(action_tokens, action_token_ids,
                                             state)
        else:
            output_dict = {}

        output_dict["metadata"] = metadata
        if not self.training:
            # Re-initialize decoder
            state = self._decoder_init(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            # Building the BLEU reference needs the gold actions and their ids,
            # so skip the metric when they were not provided.
            has_gold_actions = bool(action_tokens) and action_token_ids is not None
            if revised_tokens and has_gold_actions and self._bleu:
                top_k_predictions = output_dict["predictions"]
                best_actions = top_k_predictions[:, 0]
                best_predictions = self._action_to_token(
                    best_actions, draft_tokens["tokens"])
                gold_tokens = self._extend_gold_tokens(
                    revised_tokens["tokens"], action_tokens["tokens"],
                    triple_token_ids, action_token_ids)
                self._bleu(best_predictions, gold_tokens)
        return output_dict

    def _extend_gold_tokens(self, revised_tokens: torch.Tensor,
                            action_tokens: torch.Tensor,
                            triple_token_ids: torch.Tensor,
                            action_token_ids: torch.Tensor):
        """
        Return a copy of ``revised_tokens`` in which copied OOV tokens are replaced by their
        extended index ``vocab_size + position of the first matching triple slot``.

        The input tensor is left untouched; the result is a new tensor of the same shape.
        """
        batch_size, action_length = action_tokens.size()
        triple_size = triple_token_ids.size(1)
        expanded_triple_ids = triple_token_ids.unsqueeze(1).expand(
            batch_size, action_length, triple_size)
        expanded_revised_ids = action_token_ids.unsqueeze(-1).expand(
            batch_size, action_length, triple_size)
        match = expanded_triple_ids == expanded_revised_ids
        copied = match.sum(-1) > 0
        oov = action_tokens == self.OOV
        # 1 where the gold action is an OOV token that also occurs among the triples.
        mask = (oov & copied).long()
        # Position of the first matching triple slot for every action step.
        first_match = ((match.cumsum(-1) == 1) * match).byte().argmax(-1)
        new_action_tokens = action_tokens * (1 - mask) + (
            first_match.long() + self.vocab_size) * mask
        # Every action except DROP produces one revised token.
        increment_mask = ~(new_action_tokens == self.DROP)
        pointer = revised_tokens.new_zeros((revised_tokens.size(0), ))
        end_point = ((revised_tokens != 0).sum(dim=1) - 1)
        # Work on a copy so the caller's batch tensor is not modified in place.
        extended_tokens = revised_tokens.clone()
        for i in range(action_length):
            act_step, mask_step = new_action_tokens[:, i], mask[:, i].bool()
            extended_tokens[mask_step.nonzero().squeeze(1),
                            pointer[mask_step]] = act_step[mask_step]
            pointer[increment_mask[:, i]] += 1
            # Never move past the last non-padding revised token.
            pointer = torch.min(pointer, end_point)
        return extended_tokens

    def _action_to_token(self, action_tokens: torch.LongTensor,
                         draft_tokens: torch.LongTensor) -> torch.LongTensor:
        """
        Replay a batch of action sequences against the draft and return the resulting token
        indices, padded with ``END``. The result has the same shape as ``action_tokens``.
        """
        predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1))
        # The draft pointer starts at 1 (position 0 is presumably the start symbol -- TODO confirm).
        draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1))
        predicted_tokens = action_tokens.new_full((action_tokens.size()),
                                                  self.END)
        for act_step in action_tokens.t():
            # KEEP, DELETE, COPY, ADD (other)
            keep_mask = act_step == self.KEEP
            drop_mask = act_step == self.DROP
            add_mask = ~(keep_mask | drop_mask)
            # Tentatively write the current draft token everywhere; rows that add a token
            # are overwritten just below, and rows that drop do not advance the output
            # pointer, so their slot is rewritten by a later step.
            predicted_tokens.scatter_(1, predicted_pointer,
                                      draft_tokens.gather(1, draft_pointer))
            predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter(
                1, predicted_pointer[add_mask],
                act_step[add_mask].unsqueeze(1))
            draft_pointer[keep_mask | drop_mask] += 1
            predicted_pointer[~drop_mask] += 1
        return predicted_tokens

    def _decoder_init(self, state):
        """
        (Re-)initialise the decoding part of ``state`` in place and return it: stream
        hidden/context, draft pointer (set to 1) and the action mask (``PAD`` and ``END``
        disabled until the draft is consumed).
        """
        mean_draft = util.masked_mean(state["encoded_draft"],
                                      state["draft_mask"].unsqueeze(-1), 1)
        mean_triple = util.masked_mean(state["encoded_triple"],
                                       state["triple_mask"].unsqueeze(-1), 1)
        concatenated = torch.cat((mean_draft, mean_triple), dim=-1)
        batch_size = state["draft_mask"].size(0)
        zeros = mean_draft.new_zeros((batch_size, self.decoder_size))
        state["stream_hidden"], state["stream_context"] = self.U(
            concatenated), zeros
        state["draft_pointer"] = state["draft_mask"].new_ones((batch_size, ))
        action_mask = mean_draft.new_ones((batch_size, self.vocab_size))
        action_mask[:, self.PAD] = 0
        action_mask[:, self.END] = 0
        state["action_mask"] = action_mask
        return state

    def _init_state(self, triples: Dict[str, torch.LongTensor],
                    predicate: Dict[str, torch.LongTensor],
                    draft: Dict[str, torch.LongTensor],
                    triple_ids: torch.LongTensor) -> Dict[str, torch.Tensor]:
        """
        Encode the triples and the draft and return the encoder part of the decoding state.
        """
        # Average the predicate's token embeddings into one vector per triple.
        emb_pred = util.masked_mean(
            self.EMB(predicate),
            util.get_text_field_mask(
                predicate,
                num_wrapping_dims=1,
            ).unsqueeze(-1), 2)
        emb_triple = self.EMB(triples)
        triple_mask = util.get_text_field_mask(triples)
        flat_triples = torch.cat((emb_triple.flatten(2, 3), emb_pred), dim=-1)
        encoded_triples = self.FACT_ENCODER(flat_triples)

        emb_draft = self.EMB(draft)
        draft_mask = util.get_text_field_mask(draft)
        # Index of the last non-padding draft token.
        end_point = (draft_mask.sum(dim=1) - 1)
        encoded_draft = self.BUFFER(emb_draft, draft_mask)
        return {
            "draft_mask": draft_mask,
            "triple_mask": triple_mask,
            "end_point": end_point,
            "encoded_triple": encoded_triples,
            "encoded_draft": encoded_draft,
            "triple_tokens": triples["tokens"][:, :, -1],
            "triple_token_ids": triple_ids
        }

    def _forward_loss(
            self, target_actions: Dict[str, torch.LongTensor],
            target_token_ids: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Teacher-forced negative log-likelihood of the gold actions, summed over time steps and
        averaged over the batch. Returns ``{"loss": loss}``.
        """
        batch_size, target_sequence_length = target_actions["tokens"].size()
        num_decoding_steps = target_sequence_length - 1
        # For the current target: which triple slots hold the same token.
        target_to_triple = state["triple_mask"].new_zeros(
            state["triple_mask"].size()).bool()
        copy_input_choice = state["triple_mask"].new_full((batch_size, ),
                                                          self.COPY)
        step_log_likelihoods = []
        for t in range(num_decoding_steps):
            input_actions = target_actions["tokens"][:, t]
            if t < num_decoding_steps - 1:
                # An OOV input that was copied from a triple is fed as the COPY symbol.
                copied = (target_to_triple.sum(dim=-1) >
                          0) & (input_actions == self.OOV)
                target_to_triple = state[
                    "triple_token_ids"] == target_token_ids[:, t +
                                                            1].unsqueeze(-1)
                input_actions = copied.long() * (
                    copy_input_choice - input_actions) + input_actions
            state = self._decoder_step(input_actions, state)
            step_target_actions = target_actions["tokens"][:, t + 1]
            step_log_likelihoods.append(
                self._get_log_likelihood(state, step_target_actions,
                                         target_to_triple))
        log_likelihoods = torch.stack(step_log_likelihoods, dim=-1)
        target_mask = util.get_text_field_mask(target_actions)
        target_mask = target_mask[:, 1:].float()
        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)
        loss = -log_likelihood.sum()
        loss /= batch_size
        return {"loss": loss}

    @staticmethod
    def _get_query(state: Dict[str, torch.Tensor]):
        """Concatenate the draft encoding under the pointer with the stream hidden state."""
        batch_size = state["encoded_draft"].size(0)
        buffer_head = state["encoded_draft"][torch.arange(batch_size),
                                             state["draft_pointer"]]
        query = torch.cat([buffer_head, state["stream_hidden"]], dim=1)
        return query

    def _get_log_likelihood(self, state: Dict[str, torch.Tensor],
                            target_actions: torch.Tensor,
                            target_to_source: torch.Tensor) -> torch.Tensor:
        """
        Log-probability of ``target_actions`` under the gated mixture of generation and copy
        distributions. Returns one value per batch row.
        """
        hidden = self.P(self._get_query(state))
        gate_prob = self.G(hidden).squeeze(1)
        gen_prob = util.masked_softmax(self.W(hidden), state["action_mask"], memory_efficient=True)\
            .gather(1, target_actions.unsqueeze(1)).squeeze(1)
        # Generation only counts when the target is in-vocabulary or cannot be copied.
        gen_mask = (target_actions != self.OOV) | (target_to_source.sum(dim=-1)
                                                   == 0)
        gen_prob = gen_prob.min(gen_mask.float())
        copy_prob = self.COPY_ATTN(hidden, state["encoded_triple"], state["triple_mask"])\
            .masked_fill(~target_to_source, 0.).sum(dim=-1)
        step_prob = gen_prob * gate_prob + copy_prob * (-gate_prob + 1)
        # Clamp before the log to avoid -inf.
        step_log_likelihood = step_prob.clamp(1e-30).log()
        return step_log_likelihood

    def _decoder_step(
            self, last_actions: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Consume ``last_actions``: update the stream LSTM state, advance the draft pointer and
        refresh the action mask. ``state`` is updated in place and returned.
        """
        embed_actions = self.EMB({"tokens": last_actions})
        batch_size = embed_actions.size(0)

        # Update stack given draft pointer information
        draft_head = state["encoded_draft"][torch.arange(batch_size),
                                            state["draft_pointer"]]
        query = torch.cat([state["stream_hidden"], draft_head], dim=1)
        attend = self.ATTN(query, state["encoded_triple"],
                           state["triple_mask"])
        attended_triple = util.weighted_sum(state["encoded_triple"], attend)

        # Rows whose last action is none of the special symbols added a new token:
        # feed its projected embedding instead of the draft encoding.
        is_added = torch.stack([last_actions != tok
                                for tok in self.SYMBOL]).all(dim=0)
        draft_head[is_added] = self.ADD(embed_actions[is_added])
        hs, cs = self.STREAM(torch.cat((draft_head, attended_triple), dim=-1),
                             (state["stream_hidden"], state["stream_context"]))

        # 1.0 for rows that did NOT drop; dropped rows keep their previous stream state.
        drop_mask = (last_actions != self.DROP).unsqueeze(1).float()
        hx = drop_mask * hs + (-drop_mask + 1) * state["stream_hidden"]
        cx = drop_mask * cs + (-drop_mask + 1) * state["stream_context"]
        state["stream_hidden"], state["stream_context"] = hx, cx

        # Update Pointer
        move_forward = ((last_actions == self.KEEP) |
                        (last_actions == self.DROP)).long()
        state["draft_pointer"] = state["draft_pointer"] + move_forward

        # Simple masking for pointer
        state["draft_pointer"] = torch.min(state["draft_pointer"],
                                           state["end_point"])
        # Once the draft is consumed, KEEP/DROP are forbidden and END becomes available.
        is_ended = state["end_point"] == state["draft_pointer"]
        state["action_mask"][is_ended, self.KEEP] = 0
        state["action_mask"][is_ended, self.DROP] = 0
        state["action_mask"][is_ended, self.END] = 1
        return state

    def take_search_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Step function handed to beam search: returns the log-probabilities over the extended
        action space (vocabulary followed by one slot per triple) and the updated state.
        """
        input_choices = self._get_input(last_predictions)
        state = self._decoder_step(input_choices, state)
        final_prob = self._make_prob(state)
        return final_prob.clamp(1e-30).log(), state

    def _get_input(
            self,
            last_predictions: torch.Tensor,
    ) -> torch.Tensor:
        """Map predictions ``>= vocab_size`` (copies) to the ``COPY`` symbol; keep the rest."""
        group_size, = last_predictions.size()
        only_copy_mask = (last_predictions >= self.vocab_size).long()
        copy_input_choices = only_copy_mask.new_full((group_size, ), self.COPY)
        input_choices = (copy_input_choices -
                         last_predictions) * only_copy_mask + last_predictions
        return input_choices

    def _make_prob(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Build the probability table over ``vocab_size + triple_length`` actions.

        Copy mass of in-vocabulary triple tokens is added to the matching vocabulary entry;
        copy mass of OOV triple tokens is gathered on the first slot holding that token.
        """
        triple_token_ids = state["triple_token_ids"]
        batch_size, triple_length = triple_token_ids.size()
        hidden = self.P(self._get_query(state))
        gate_prob = self.G(hidden)
        gen_prob = util.masked_softmax(
            self.W(hidden), state["action_mask"],
            memory_efficient=True) * gate_prob
        copy_prob = self.COPY_ATTN(hidden, state["encoded_triple"],
                                   state["triple_mask"]) * (-gate_prob + 1)
        modified_prob_list: List[torch.Tensor] = []
        for i in range(triple_length):
            copy_prob_slice = copy_prob[:, i]
            token_slice = state["triple_tokens"][:, i]
            copy_to_add_mask = token_slice != self.OOV
            # In-vocabulary triple tokens: move their copy mass to the vocabulary entry.
            copy_to_add = copy_prob_slice.min(
                copy_to_add_mask.float()).unsqueeze(-1)
            gen_prob = gen_prob.scatter_add(-1, token_slice.unsqueeze(1),
                                            copy_to_add)
            if i < (triple_length - 1):
                # Pull in the mass of later slots that hold the same token.
                future_occurrences = (
                    (triple_token_ids[:, i + 1:]
                     ) == triple_token_ids[:, i].unsqueeze(-1)).float()
                future_copy_prob = copy_prob[:, i + 1:].min(future_occurrences)
                copy_prob_slice += future_copy_prob.sum(-1)
            if i > 0:
                # Zero out slots whose token already appeared earlier.
                prev_occurrences = triple_token_ids[:, :
                                                    i] == triple_token_ids[:,
                                                                           i].unsqueeze(
                                                                               -1
                                                                           )
                duplicate_mask = (prev_occurrences.sum(-1) == 0).float()
                copy_prob_slice = copy_prob_slice.min(duplicate_mask)
            left_over_copy_prob = copy_prob_slice.min(
                (~copy_to_add_mask).float())
            modified_prob_list.append(left_over_copy_prob.unsqueeze(-1))
        modified_prob_list.insert(0, gen_prob)
        modified_prob = torch.cat(modified_prob_list, dim=-1)
        return modified_prob

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run beam search from ``START`` and return predictions with their log-probs."""
        batch_size = state["draft_mask"].size(0)
        start_predictions = state["draft_mask"].new_full((batch_size, ),
                                                         self.START)
        all_top_k_predictions, log_probabilities = self.BEAM.search(
            start_predictions, state, self.take_search_step)
        return {
            "predicted_log_probs": log_probabilities,
            "predictions": all_top_k_predictions
        }

    def _get_predicted_tokens(self,
                              predicted_indices: Union[torch.Tensor,
                                                       numpy.ndarray],
                              batch_metadata,
                              n_best: int = None):
        """
        Turn predicted action indices into token strings using each instance's ``"draft"``
        and ``"triple"`` metadata.

        Returns, per instance, a list of token lists (one per beam), or a single token list
        when ``n_best == 1``. ``n_best=None`` keeps every beam.
        """
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        predicted_tokens = []
        for top_k_predictions, metadata in zip(predicted_indices,
                                               batch_metadata):
            batch_predicted_tokens = []
            draft, triple = metadata['draft'], metadata["triple"]
            for indices in top_k_predictions[:n_best]:
                pointer, tokens = 0, []
                indices = list(indices)
                # Cut the sequence at the first END action.
                if self.END in indices:
                    indices = indices[:indices.index(self.END)]
                for index in indices:
                    if index == self.KEEP:
                        tokens.append(draft[pointer])
                        pointer += 1
                    elif index == self.DROP:
                        pointer += 1
                    elif index >= self.vocab_size:
                        # Extended index: copy from the triple.
                        adjusted_index = index - self.vocab_size
                        tokens.append(triple[adjusted_index])
                    else:
                        tokens.append(
                            str(self.vocab.get_token_from_index(index)))
                batch_predicted_tokens.append(tokens)
            if n_best == 1:
                predicted_tokens.append(batch_predicted_tokens[0])
            else:
                predicted_tokens.append(batch_predicted_tokens)
        return predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Add ``"predicted_tokens"`` (token strings for every beam) to ``output_dict``."""
        predicted_tokens = self._get_predicted_tokens(
            output_dict["predictions"], output_dict["metadata"])
        output_dict["predicted_tokens"] = predicted_tokens
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU when it is enabled and the model is not in training mode."""
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
class DropSeq2Seq(Model):
    """
    Encoder/decoder :class:`Model` with dropout on the encoder outputs and around the decoder
    cell. A source sequence is embedded and encoded, and an LSTM cell (optionally with
    attention over the encoder outputs) decodes the target sequence. Beam search is used
    whenever the model is not in training mode.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Holds both source and target vocabularies, possibly under different namespaces.
    source_embedder : ``TextFieldEmbedder``, required
        Embeds the source tokens.
    encoder : ``Seq2SeqEncoder``, required
        Encodes the embedded source sequence.
    max_decoding_steps : ``int``
        Upper bound on the length of decoded sequences.
    attention : ``Attention``, optional (default = None)
        Attention module scoring encoder outputs against the decoder hidden state.
    attention_function : ``SimilarityFunction``, optional (default = None)
        Legacy alternative to ``attention``; the two are mutually exclusive.
    beam_size : ``int``, optional (default = None)
        Beam width. A missing value is treated as 1.
    target_namespace : ``str``, optional (default = 'tokens')
        Vocabulary namespace of the target tokens.
    target_embedding_dim : ``int``, optional (default = source embedder output dim)
        Size of the target-side embeddings.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        During training, at each timestep a number is drawn uniformly from [0, 1); if it is
        below this ratio the previous predictions are fed to the decoder for the whole batch,
        otherwise the gold tokens are used.
    use_bleu : ``bool``, optional (default = True)
        Track BLEU when the model is not in training mode.
    emb_dropout : ``float``, optional (default = 0.5)
        Dropout probability applied to the encoder outputs.
    dec_dropout : ``float``, optional (default = 0.5)
        Dropout probability applied to the decoder cell input and to its hidden state before
        the output projection.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.5,
                 dec_dropout: float = 0.5) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # Decoding starts from the start symbol and stops at the end symbol.
        namespace = self._target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL, namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, namespace)

        # Metrics.
        self._bleu = None
        if use_bleu:
            padding_id = self.vocab.get_token_index(self.vocab._padding_token,  # pylint: disable=protected-access
                                                    namespace)
            self._bleu = BLEU(exclude_indices={padding_id, self._end_index, self._start_index})
        self._token_based_metric = TokenSequenceAccuracy()

        # Beam search drives prediction; a missing beam size means a beam of one.
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        # Source side: embedder, dropout layers and encoder.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(namespace)

        # Attention over encoder outputs: either the module or the legacy function.
        if attention and attention_function:
            raise ConfigurationError("You can only specify an attention module or an "
                                     "attention function, but not both.")
        if attention:
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Target-side embedding.
        if not target_embedding_dim:
            target_embedding_dim = source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # The decoder is initialised from the encoder's final state, so both share a size.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        # With attention, the attended encoder summary is concatenated to the target embedding.
        attended_dim = self._decoder_output_dim if self._attention else 0
        self._decoder_input_dim = attended_dim + target_embedding_dim

        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # Maps the decoder hidden state to scores over the target vocabulary.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Single decoding step, used as the step function of the beam search.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            Shape ``(group_size,)``; the indices predicted at the previous step.
        state : ``Dict[str, torch.Tensor]``
            Encoder outputs, source mask and the decoder hidden state and context, each of
            shape ``(group_size, *)``.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            Log-probabilities of shape ``(group_size, num_classes)`` and the updated state.

        Notes
        -----
        ``group_size`` may differ from ``batch_size`` because the group can hold several
        hypotheses per source sentence.
        """
        # shape: (group_size, num_classes)
        scores, state = self._prepare_output_projections(last_predictions, state)
        return F.log_softmax(scores, dim=-1), state

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]],
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Encode the source, compute the loss when targets are given and, outside training
        mode, decode with beam search and update the metrics.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
            Tensorised source ``TextField``.
        metadata : ``List[Dict[str, Any]]``
            Per-instance metadata; ``"target_tokens"`` is read from it for the token metric.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Tensorised target ``TextField``.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        output_dict: Dict[str, torch.Tensor] = {}
        if target_tokens:
            # Teacher-forced pass that yields the loss.
            output_dict = self._forward_loop(self._init_decoder_state(state), target_tokens)

        if self.training:
            return output_dict

        state = self._init_decoder_state(state)
        output_dict.update(self._forward_beam_search(state))
        if target_tokens:
            # shape: (batch_size, beam_size, max_sequence_length) -> best beam only
            best_predictions = output_dict["predictions"][:, 0, :]
            predicted_tokens = self.decode(output_dict)["predicted_tokens"]
            gold_tokens = [instance["target_tokens"] for instance in metadata]
            self._token_based_metric(predicted_tokens, gold_tokens)
            if self._bleu:
                self._bleu(best_predictions, target_tokens["tokens"])
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Convert ``output_dict["predictions"]`` to tokens and store them under
        ``"predicted_tokens"``. Only the best beam is kept, and each sequence is cut at the
        first end symbol.
        """
        raw_predictions = output_dict["predictions"]
        if not isinstance(raw_predictions, numpy.ndarray):
            raw_predictions = raw_predictions.detach().cpu().numpy()

        end_id = self._end_index
        decoded = []
        for row in raw_predictions:
            # A 2-d row holds the top-k beams; the first one is the best.
            best = row[0] if len(row.shape) > 1 else row
            ids = list(best)
            if end_id in ids:
                ids = ids[:ids.index(end_id)]
            decoded.append([self.vocab.get_token_from_index(i, namespace=self._target_namespace)
                            for i in ids])
        output_dict["predicted_tokens"] = decoded
        return output_dict

    def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source; dropout is applied to the encoder outputs."""
        # shape: (batch_size, max_input_sequence_length)
        mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoded = self._emb_dropout(self._encoder(embedded, mask))
        return {"source_mask": mask, "encoder_outputs": encoded}

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Set ``decoder_hidden`` to the encoder's final state and ``decoder_context`` to zeros.
        ``state`` is updated in place and returned.
        """
        encoder_outputs = state["encoder_outputs"]
        source_mask = state["source_mask"]
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = util.get_final_encoder_states(
                encoder_outputs, source_mask, self._encoder.is_bidirectional())
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = encoder_outputs.new_zeros(source_mask.size(0),
                                                             self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Step-by-step decoding: teacher forcing / scheduled sampling when targets are given,
        greedy search otherwise.

        Notes
        -----
        The greedy predictions are mainly useful to check that beam search with a beam of
        size one produces the same output.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size(0)

        num_decoding_steps = self._max_decoding_steps
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]
            # The last target position is padding or the end symbol; it is never an input.
            num_decoding_steps = targets.size(1) - 1

        # shape: (batch_size,) -- decoding starts from the start symbol.
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        logits_per_step: List[torch.Tensor] = []
        predictions_per_step: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            # The random draw only happens in training mode (short-circuit).
            feed_own_prediction = (self.training and
                                   torch.rand(1).item() < self._scheduled_sampling_ratio)
            if feed_own_prediction or not target_tokens:
                # shape: (batch_size,)
                decoder_inputs = last_predictions
            else:
                # shape: (batch_size,)
                decoder_inputs = targets[:, timestep]

            # shape: (batch_size, num_classes)
            scores, state = self._prepare_output_projections(decoder_inputs, state)
            # shape: (batch_size, 1, num_classes)
            logits_per_step.append(scores.unsqueeze(1))

            # shape: (batch_size, num_classes)
            probabilities = F.softmax(scores, dim=-1)
            # shape: (batch_size,)
            last_predictions = torch.max(probabilities, 1)[1]
            predictions_per_step.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        output_dict = {"predictions": torch.cat(predictions_per_step, 1)}
        if not target_tokens:
            return output_dict

        # shape: (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(logits_per_step, 1)
        target_mask = util.get_text_field_mask(target_tokens)
        output_dict["target_mask"] = target_mask
        output_dict["loss"] = self._get_loss(logits, targets, target_mask)
        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Decode with beam search, starting every sequence from the start symbol."""
        source_mask = state["source_mask"]
        first_inputs = source_mask.new_full((source_mask.size(0),),
                                            fill_value=self._start_index)
        # shape (top_k): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probs): (batch_size, beam_size)
        top_k, log_probs = self._beam_search.search(first_inputs, state, self.take_step)
        return {
                "class_log_probabilities": log_probs,
                "predictions": top_k,
        }

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Run the decoder cell for one step and project its hidden state onto the target
        vocabulary. Takes the same inputs as ``take_step()``; ``state`` is updated in place
        with the new hidden state and context.
        """
        # shape: (group_size, decoder_output_dim) for both
        previous_hidden = state["decoder_hidden"]
        previous_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        cell_input = self._target_embedder(last_predictions)
        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended = self._prepare_attended_input(previous_hidden,
                                                    state["encoder_outputs"],
                                                    state["source_mask"])
            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            cell_input = torch.cat((attended, cell_input), -1)
        cell_input = self._dec_dropout(cell_input)

        # shape: (group_size, decoder_output_dim) for both
        new_hidden, new_context = self._decoder_cell(cell_input,
                                                     (previous_hidden, previous_context))
        state["decoder_hidden"] = new_hidden
        state["decoder_context"] = new_context

        # shape: (group_size, num_classes)
        return self._output_projection_layer(self._dec_dropout(new_hidden)), state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Attend over the encoder outputs with the decoder hidden state as the query."""
        # The attention module multiplies by the mask, so it has to be a float tensor.
        # shape: (batch_size, max_input_sequence_length)
        float_mask = encoder_outputs_mask.float()
        # shape: (batch_size, max_input_sequence_length)
        weights = self._attention(decoder_hidden_state, encoder_outputs, float_mask)
        # shape: (batch_size, encoder_output_dim)
        return util.weighted_sum(encoder_outputs, weights)

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Masked sequence cross entropy between ``logits`` and the targets shifted by one.

        ``logits`` has shape ``(batch_size, num_decoding_steps, num_classes)`` while
        ``targets`` and ``target_mask`` have one extra timestep: the logit produced at step
        ``i`` is scored against the target at step ``i + 1``. For a padded target::

            targets   <S> w1 w2 w3 <E> <P> <P>
            mask       1  1  1  1  1   0   0

        the comparison is made between::

            logits    l1 l2 l3 l4  l5  l6
            targets   w1 w2 w3 <E> <P> <P>
            mask      1  1  1  1   0   0
        """
        # shape: (batch_size, num_decoding_steps) for both
        shifted_targets = targets[:, 1:].contiguous()
        shifted_mask = target_mask[:, 1:].contiguous()
        return util.sequence_cross_entropy_with_logits(logits, shifted_targets, shifted_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return token-sequence accuracy (and BLEU if enabled); empty while training."""
        collected: Dict[str, float] = {}
        if self.training:
            return collected
        collected.update(self._token_based_metric.get_metric(reset=reset))
        if self._bleu:
            collected.update(self._bleu.get_metric(reset=reset))
        return collected
class RecombinationSeq2SeqWithCopy(Model):
    """
    Attention-based encoder/decoder with a copy mechanism.

    At every decoding step the model scores ``num_classes`` target-vocabulary
    entries followed by one entry per source position. A predicted index
    ``>= num_classes`` therefore means "copy the source token at position
    ``index - num_classes``".

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Used to look up the start, end and padding indices in ``target_namespace``
        and to translate copied source tokens into target indices.
    source_embedder : ``TextFieldEmbedder``, required
        Embeds the source tokens.
    encoder : ``Seq2SeqEncoder``, required
        Encodes the embedded source sequence.
    max_decoding_steps : ``int``, required
        Number of steps for beam search, and for greedy decoding when no targets are given.
    seq_metrics : ``Metric``, required
        Called at evaluation time with the mapped best predictions, ``gold_labels`` and
        ``mask``. ``get_metrics`` reads the ``'accuracy'`` key of its result.
    attention : ``Attention``, required
        Called with ``(decoder_hidden, encoder_outputs, mask)``. Its output is used both
        (after a masked softmax) to build the attention context and, as is, as the copy
        scores. NOTE(review): this presumably requires un-normalized attention scores -- confirm.
    beam_size : ``int``, optional (default = 1)
        Width of the beam search.
    source_namespace : ``str``, optional (default = 'source_tokens')
        Vocabulary namespace of the source tokens.
    target_namespace : ``str``, optional (default = 'tokens')
        Vocabulary namespace of the target tokens.
    target_embedding_dim : ``int``, optional (default = ``source_embedder.get_output_dim()``)
        Dimensionality of the target embeddings.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        During training, probability (per timestep) of feeding the model's previous
        prediction to the decoder instead of the gold token.
    use_bleu : ``bool``, optional (default = False)
        Whether to track BLEU at evaluation time.
    encoder_input_dropout : ``float``, optional (default = 0.0)
        Dropout applied to the embedded source tokens.
    encoder_output_dropout : ``float``, optional (default = 0.0)
        Dropout applied to the encoder outputs.
    dropout : ``float``, optional (default = 0.0)
        Dropout applied to the embedded decoder input and to the input of the
        output projection layer.
    feed_output_attention_to_decoder : ``bool``, optional (default = False)
        If set, the attention context of the previous step is concatenated to the
        decoder input.
    keep_decoder_output_dim_same_as_encoder : ``bool``, optional (default = True)
        If unset and the encoder is bidirectional, the decoder hidden size is half of
        the encoder output size.
    initializer : ``InitializerApplicator``, optional
        Applied to the model once all sub-modules have been created.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 seq_metrics: Metric,
                 attention: Attention,
                 beam_size: int = None,
                 source_namespace: str = 'source_tokens',
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = False,
                 encoder_input_dropout: float = 0.0,
                 encoder_output_dropout: float = 0.0,
                 dropout: float = 0.0,
                 feed_output_attention_to_decoder: bool = False,
                 keep_decoder_output_dim_same_as_encoder: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
        super().__init__(vocab)
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,  # pylint: disable=protected-access
                                                     self._target_namespace)

        # Evaluation metrics.
        if use_bleu:
            self._bleu = BLEU(exclude_indices={self._pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None
        self._seq_metric = seq_metrics

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder
        self._encoder_output_dim = self._encoder.get_output_dim()

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention
        self._feed_output_attention_to_decoder = feed_output_attention_to_decoder

        # Resolve the default *before* it is used to size the decoder input; otherwise a
        # ``None`` value would leak into ``_decoder_input_dim``.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        if self._feed_output_attention_to_decoder:
            # The attention context (a weighted average over encoder outputs) is concatenated
            # to the previous target embedding to form the decoder input at each time step.
            self._decoder_input_dim = self._encoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # Dense embedding of vocab words in the target space.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._num_classes = num_classes
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # The decoder is initialized from the final encoder state (through a learned
        # projection), so its size is derived from the encoder output size.
        self._keep_decoder_output_dim_same_as_encoder = keep_decoder_output_dim_same_as_encoder
        if not self._keep_decoder_output_dim_same_as_encoder and encoder.is_bidirectional():
            self._decoder_output_dim = int(self._encoder_output_dim / 2)
        else:
            self._decoder_output_dim = self._encoder_output_dim

        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
        self._transform_decoder_init_state = torch.nn.Sequential(
            torch.nn.Linear(self._encoder_output_dim, self._decoder_output_dim),
            torch.nn.Tanh()
        )

        # Scores for generating a token from the target vocabulary.
        self._output_projection_layer = Linear(self._decoder_output_dim + self._encoder_output_dim,
                                               num_classes)

        # Dropout layers.
        self._encoder_input_dropout = torch.nn.Dropout(p=encoder_input_dropout)
        self._encoder_output_dropout = torch.nn.Dropout(p=encoder_output_dropout)
        self._output_dropout = torch.nn.Dropout(p=dropout)
        self._embedded_dropout = torch.nn.Dropout(p=dropout)

        initializer(self)

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Run one decoder step and return the un-normalized scores for the next token.

        Dropout is applied before the output projection (following "Language to Logical
        Form with Neural Attention").

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            Shape ``(group_size,)``. Indices ``>= num_classes`` denote copied source positions.
        state : ``Dict[str, torch.Tensor]``
            Decoding state; ``decoder_hidden``, ``decoder_context`` and (when attention is
            fed to the decoder) ``attention_context`` are updated in place.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            Scores of shape ``(group_size, num_classes + max_input_sequence_length)`` and
            the updated state.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]
        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]
        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # 1 where the previous prediction is a regular vocabulary token, 0 where it is a copy.
        # Copy indices are zeroed so that the embedding lookup stays in range.
        copy_mask = (last_predictions < self._num_classes).long()
        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions * copy_mask)

        if not self.training and copy_mask.sum() < copy_mask.size(0):
            # At least one previous prediction was a copy: look up the copied source token
            # and embed its target-vocabulary index instead.
            source_token_ids = state['source_token_ids']
            mapped_indices = []
            for gidx, idx in enumerate(last_predictions):
                if idx >= self._num_classes:
                    source_idx = idx - self._num_classes
                    source_token_id = int(source_token_ids[gidx, source_idx])
                    token = self.vocab.get_token_from_index(source_token_id, self._source_namespace)
                    mapped_indices.append(self.vocab.get_token_index(token, self._target_namespace))
                else:
                    # Placeholder; masked out below.
                    mapped_indices.append(self._pad_index)
            mapped_indices = torch.tensor(mapped_indices, dtype=torch.long,
                                          device=last_predictions.device)
            copied_embedded_input = self._target_embedder(mapped_indices)
            unsqueezed_copy_mask = copy_mask.unsqueeze(dim=1).float()
            embedded_input = (embedded_input * unsqueezed_copy_mask
                              + copied_embedded_input * (1 - unsqueezed_copy_mask))
        embedded_input = self._embedded_dropout(embedded_input)

        if self._feed_output_attention_to_decoder:
            # shape: (group_size, target_embedding_dim + encoder_output_dim)
            decoder_input = torch.cat((embedded_input, state["attention_context"]), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        # shape (decoder_hidden): (group_size, decoder_output_dim)
        # shape (decoder_context): (group_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))
        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # output_attended_input shape: (group_size, encoder_output_dim)
        # attention_weights shape: (group_size, max_input_sequence_length)
        output_attended_input, attention_weights = self._prepare_output_attended_input(
            decoder_hidden, encoder_outputs, source_mask)
        if self._feed_output_attention_to_decoder:
            state["attention_context"] = output_attended_input

        output_projection_input = torch.cat((decoder_hidden, output_attended_input), -1)
        dropped_output_projection_input = self._output_dropout(output_projection_input)
        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(dropped_output_projection_input)
        # Append the copy scores.
        # shape: (group_size, num_classes + max_input_sequence_length)
        output_projections = torch.cat((output_projections, attention_weights), -1)
        return output_projections, state

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information needed to
            predict the next step, which includes the encoder outputs, the source mask, and
            the decoder hidden state and context.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            has shape ``(group_size, num_classes + max_input_sequence_length)`` and
            ``updated_state`` is the state returned by the decoder step.

        Notes
        -----
        We treat the inputs as a batch, even though ``group_size`` is not necessarily
        equal to ``batch_size``, since the group may contain multiple states
        for each source sentence in the batch.
        """
        # shape: (group_size, num_classes + max_input_sequence_length)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        source_mask = state['source_mask']
        group_size = source_mask.size(0)
        # Every vocabulary entry is allowed; copy entries are restricted to unmasked
        # source positions.
        # shape: (group_size, num_classes + max_input_sequence_length)
        normalization_mask = torch.cat(
            [source_mask.new_ones((group_size, self._num_classes)), source_mask], dim=-1)

        # shape: (group_size, num_classes + max_input_sequence_length)
        class_log_probabilities = util.masked_log_softmax(output_projections, normalization_mask, dim=-1)
        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None,
                target_source_token_map: torch.Tensor = None,
                meta_field: List[Dict] = None,
                ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
            The output of `TextField.as_array()` applied on the source `TextField`. This will be
            passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on target `TextField`. When given, the
            loss is computed.
        target_source_token_map : ``torch.Tensor``, optional (default = None)
            Shape ``(batch_size, target_length, source_length)``; passed to the loss together
            with ``target_tokens``.
        meta_field : ``List[Dict]``, optional (default = None)
            Per-instance metadata. Each entry's ``'source_tokens_to_copy'`` is read when
            metrics are updated (evaluation with targets).

        Returns
        -------
        Dict[str, torch.Tensor]
            Contains ``loss`` when targets are given and, outside of training, the beam
            search outputs and ``source_token_ids``.
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens, target_source_token_map)
        else:
            output_dict = {}

        if not self.training:
            # Start from a fresh decoder state; `_forward_loop` has advanced the previous one.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            output_dict.update({"source_token_ids": source_tokens['tokens']})

            if target_tokens:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = self.map_predictions(
                    top_k_predictions[:, 0, :], source_tokens['tokens'], meta_field)
                if self._bleu:
                    self._bleu(best_predictions, target_tokens["tokens"])
                if self._seq_metric:
                    self._seq_metric(
                        best_predictions.float(),
                        gold_labels=target_tokens["tokens"][:, 1:].float(),
                        mask=util.get_text_field_mask(target_tokens).float()[:, 1:]
                    )

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        Trims each predicted sequence at the first end symbol and replaces indices by
        tokens, stored under ``predicted_tokens``. Vocabulary indices become the token
        returned by the vocabulary; copy indices become a ``Token`` whose text is
        ``@@copy@@<source position>``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()

        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]

            predicted_tokens = []
            for x in indices:
                if x < self._num_classes:
                    predicted_tokens.append(
                        self.vocab.get_token_from_index(x, namespace=self._target_namespace))
                else:
                    source_idx = x - self._num_classes
                    predicted_tokens.append(Token("@@copy@@%d" % int(source_idx)))
            all_predicted_tokens.append(predicted_tokens)

        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source; return the initial state dictionary."""
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        embedded_input = self._encoder_input_dropout(embedded_input)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        encoder_outputs = self._encoder_output_dropout(encoder_outputs)
        return {
            "source_token_ids": source_tokens['tokens'],
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Add (or reset) the decoder entries of ``state`` and return it."""
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with a projection of the final encoder output.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = self._transform_decoder_init_state(final_encoder_output)
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        if self._feed_output_attention_to_decoder:
            # shape: (batch_size, encoder_output_dim)
            state["attention_context"] = state["encoder_outputs"].new_zeros(batch_size,
                                                                            self._encoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None,
                      target_source_token_map: torch.Tensor = None
                      ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Returns a dictionary with the greedy ``predictions`` and, when ``target_tokens``
        is given, the ``loss``.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]
            _, target_sequence_length = targets.size()
            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        # The mask does not depend on the timestep, so build it once.
        # shape: (batch_size, num_classes + max_input_sequence_length)
        normalization_mask = torch.cat(
            [source_mask.new_ones((batch_size, self._num_classes)), source_mask], dim=-1)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: feed the model's own previous prediction.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # No gold tokens available: greedy decoding.
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # Teacher forcing with the gold token.
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes + max_input_sequence_length)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes + max_input_sequence_length)
            step_logits.append(output_projections.unsqueeze(1))

            class_probabilities = util.masked_softmax(output_projections, normalization_mask, dim=-1)

            # shape: (batch_size,)
            _, last_predictions = torch.max(class_probabilities, 1)
            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)
        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes + max_input_sequence_length)
            logits = torch.cat(step_logits, 1)
            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            output_dict["loss"] = self._get_loss(logits, targets, target_mask, target_source_token_map)

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        return {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }

    def _prepare_output_attended_input(self,
                                       decoder_hidden_state: torch.Tensor = None,
                                       encoder_outputs: torch.Tensor = None,
                                       encoder_outputs_mask: torch.LongTensor = None) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply output attention over encoder outputs and decoder state.

        Returns the attention context (encoder outputs averaged with the masked-softmax
        of the attention scores) together with the raw attention scores, which the
        caller uses as copy scores.
        """
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
        normalized_weights = util.masked_softmax(input_weights, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, normalized_weights)

        return attended_input, input_weights

    def _get_loss(self,
                  logits: torch.FloatTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor,
                  target_source_token_map: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative log-likelihood of the targets, allowing generation or copy.

        Parameters
        ----------
        logits : ``torch.FloatTensor``
            Shape ``(batch_size, num_decoding_steps, num_classes + source_length)``.
        targets : ``torch.LongTensor``
            Shape ``(batch_size, num_decoding_steps + 1)``.
        target_mask : ``torch.LongTensor``
            Shape ``(batch_size, num_decoding_steps + 1)``.
        target_source_token_map : ``torch.Tensor``
            Shape ``(batch_size, num_decoding_steps + 1, source_length)``; multiplied with
            the copy probabilities to select the source positions that count for each
            target step.

        Notes
        -----
        The logit of timestep i is compared with the target of timestep i + 1, so the
        first column of ``targets``, ``target_mask`` and ``target_source_token_map`` is
        dropped. For a target ``<S> w1 w2 w3 <E> <P> <P>`` with mask ``1 1 1 1 1 0 0`` and
        logits ``l1 .. l6`` we compare ``w1 w2 w3 <E> <P> <P>`` (mask ``1 1 1 1 0 0``)
        against ``l1 .. l6``.

        The per-token log-likelihood is ``log(p_generate + p_copy)``; it is averaged over
        the unmasked tokens of each sequence, then over the non-empty sequences.
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()
        # shape: (batch_size, num_decoding_steps, source_length)
        target_source_token_map = target_source_token_map[:, 1:, :]

        # NOTE(review): unlike `take_step`, this softmax is not masked over padded source
        # positions.
        probs = F.softmax(logits, dim=-1)

        # Probability of generating the gold token from the vocabulary.
        # shape: (batch_size, num_decoding_steps)
        generate_probs = probs[:, :, :self._num_classes].gather(
            dim=-1, index=relevant_targets.long().unsqueeze(-1)).squeeze(-1)

        # Probability of copying the gold token from any matching source position.
        # shape: (batch_size, num_decoding_steps)
        copy_probs = (probs[:, :, self._num_classes:] * target_source_token_map).sum(dim=-1)

        # 1e-13 guards against log(0).
        target_log_probs = torch.log(generate_probs + copy_probs + 1e-13)
        target_log_probs = target_log_probs * relevant_mask.float()
        negative_log_likelihood = -1 * target_log_probs

        weights_batch_sum = relevant_mask.sum(-1).float()
        per_batch_loss = negative_log_likelihood.sum(dim=1) / (weights_batch_sum + 1e-13)
        num_non_empty_sequences = ((weights_batch_sum > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU (if enabled) and sequence accuracy; empty while training."""
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._bleu:
                all_metrics.update(self._bleu.get_metric(reset=reset))
            if self._seq_metric:
                all_metrics.update(
                    {"accuracy": self._seq_metric.get_metric(reset)['accuracy']})
        return all_metrics

    def map_predictions(self,
                        predictions: torch.LongTensor,
                        source_token_ids: torch.LongTensor,
                        meta_field: List[Dict]) -> torch.LongTensor:
        """
        Map copy indices in ``predictions`` to target-vocabulary indices.

        Parameters
        ----------
        predictions : ``torch.LongTensor``
            Shape ``(batch_size, max_length)``.
        source_token_ids : ``torch.LongTensor``
            Unused; kept for interface compatibility.
        meta_field : ``List[Dict]``
            One entry per batch item; ``'source_tokens_to_copy'`` lists the source tokens
            addressed by the copy indices.

        Returns
        -------
        torch.LongTensor
            Same shape as ``predictions``. Vocabulary indices are kept; copy indices are
            replaced by the target index of the copied token, or by the padding index
            when the copy position lies outside ``source_tokens_to_copy``.
        """
        batch_size, max_length = predictions.size()
        mapped_predictions = predictions.new_full((batch_size, max_length), fill_value=self._pad_index)
        # Convert once to Python ints instead of indexing the tensor element by element.
        for i, row in enumerate(predictions.tolist()):
            source_tokens_to_copy = meta_field[i]['source_tokens_to_copy']
            for j, idx in enumerate(row):
                if idx < self._num_classes:
                    mapped_predictions[i, j] = idx
                    continue
                # Copy
                source_idx = idx - self._num_classes
                if source_idx >= len(source_tokens_to_copy):
                    # Position points past the copyable tokens (e.g. into padding).
                    tid = self._pad_index
                else:
                    token = source_tokens_to_copy[source_idx]
                    tid = self.vocab.get_token_index(token, self._target_namespace)
                mapped_predictions[i, j] = tid
        return mapped_predictions.long()
class Event2Mind(Model): """ This ``Event2Mind`` class is a :class:`Model` which takes an event sequence, encodes it, and then uses the encoded representation to decode several mental state sequences. It is based on `the paper by Rashkin et al. <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_ Parameters ---------- vocab : ``Vocabulary``, required Vocabulary containing source and target vocabularies. They may be under the same namespace (``tokens``) or the target tokens can have a different namespace, in which case it needs to be specified as ``target_namespace``. source_embedder : ``TextFieldEmbedder``, required Embedder for source side sequences. embedding_dropout: float, required The amount of dropout to apply after the source tokens have been embedded. encoder : ``Seq2VecEncoder``, required The encoder of the "encoder/decoder" model. max_decoding_steps : int, required Length of decoded sequences. beam_size : int, optional (default = 10) The width of the beam search. target_names: ``List[str]``, optional, (default = ['xintent', 'xreact', 'oreact']) Names of the target fields matching those in the ``Instance`` objects. target_namespace : str, optional (default = 'tokens') If the target side vocabulary is different from the source side's, you need to specify the target's namespace here. If not, we'll assume it is "tokens", which is also the default choice for the source side, and this might cause them to share vocabularies. target_embedding_dim : int, optional (default = source_embedding_dim) You can specify an embedding dimensionality for the target side. If not, we'll use the same value as the source embedder's. 
""" def __init__(self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, embedding_dropout: float, encoder: Seq2VecEncoder, max_decoding_steps: int, beam_size: int = 10, target_names: List[str] = None, target_namespace: str = "tokens", target_embedding_dim: int = None) -> None: super().__init__(vocab) target_names = target_names or ["xintent", "xreact", "oreact"] # Note: The original tweaks the embeddings for "personx" to be the mean # across the embeddings for "he", "she", "him" and "her". Similarly for # "personx's" and so forth. We could consider that here as a well. self._source_embedder = source_embedder self._embedding_dropout = nn.Dropout(embedding_dropout) self._encoder = encoder self._max_decoding_steps = max_decoding_steps self._target_namespace = target_namespace # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) # Warning: The different decoders share a vocabulary! This may be # counterintuitive, but consider the case of xreact and oreact. A # reaction of "happy" could easily apply to both the subject of the # event and others. This could become less appropriate as more decoders # are added. num_classes = self.vocab.get_vocab_size(self._target_namespace) # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with that of the final hidden states of the encoder. 
self._decoder_output_dim = self._encoder.get_output_dim() target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim( ) self._states = ModuleDict() for name in target_names: self._states[name] = StateDecoder(num_classes, target_embedding_dim, self._decoder_output_dim) self._beam_search = BeamSearch(self._end_index, beam_size=beam_size, max_steps=max_decoding_steps) def _update_recall(self, all_top_k_predictions: torch.Tensor, target_tokens: Dict[str, torch.LongTensor], target_recall: UnigramRecall) -> None: targets = target_tokens["tokens"] target_mask = get_text_field_mask(target_tokens) # See comment in _get_loss. # TODO(brendanr): Do we need contiguous here? relevant_targets = targets[:, 1:].contiguous() relevant_mask = target_mask[:, 1:].contiguous() target_recall(all_top_k_predictions, relevant_targets, relevant_mask, self._end_index) def _get_num_decoding_steps( self, target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int: if target_tokens: targets = target_tokens["tokens"] target_sequence_length = targets.size()[1] # The last input from the target is either padding or the end # symbol. Either way, we don't have to process it. (To be clear, # we do still output and compare against the end symbol, but there # is no need to take the end symbol as input to the decoder.) return target_sequence_length - 1 else: return self._max_decoding_steps @overrides def forward( self, # type: ignore source: Dict[str, torch.LongTensor], **target_tokens: Dict[str, Dict[str, torch.LongTensor]] ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Decoder logic for producing the target sequences. Parameters ---------- source : ``Dict[str, torch.LongTensor]`` The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. 
target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``: Dictionary from name to output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the target tokens are also represented as a ``TextField``. """ # (batch_size, input_sequence_length, embedding_dim) embedded_input = self._embedding_dropout(self._source_embedder(source)) source_mask = get_text_field_mask(source) # (batch_size, encoder_output_dim) final_encoder_output = self._encoder(embedded_input, source_mask) output_dict = {} # Perform greedy search so we can get the loss. if target_tokens: if target_tokens.keys() != self._states.keys(): target_only = target_tokens.keys() - self._states.keys() states_only = self._states.keys() - target_tokens.keys() raise Exception( "Mismatch between target_tokens and self._states. Keys in " + f"targets only: {target_only} Keys in states only: {states_only}" ) total_loss = 0 for name, state in self._states.items(): loss = self.greedy_search( final_encoder_output=final_encoder_output, target_tokens=target_tokens[name], target_embedder=state.embedder, decoder_cell=state.decoder_cell, output_projection_layer=state.output_projection_layer) total_loss += loss output_dict[f"{name}_loss"] = loss # Use mean loss (instead of the sum of the losses) to be comparable to the paper. output_dict["loss"] = total_loss / len(self._states) # Perform beam search to obtain the predictions. 
        # NOTE(review): this is the tail of ``forward``; its beginning is outside
        # this view. Beam search is run only when the model is not in training mode.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                # Every sequence in the batch starts from the start symbol.
                start_predictions = final_encoder_output.new_full(
                    (batch_size, ),
                    fill_value=self._start_index,
                    dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, 10, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                    start_predictions, start_state, state.take_step)

                # Recall can only be updated when gold targets were supplied.
                if target_tokens:
                    self._update_recall(all_top_k_predictions,
                                        target_tokens[name], state.recall)
                output_dict[
                    f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[
                    f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self, final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding, decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Runs the provided ``decoder_cell`` over the target sequence and returns
        the loss computed by ``self._get_loss``.

        At every timestep the decoder input is the gold token
        ``target_tokens["tokens"][:, timestep]`` (i.e. the decoder is fed the
        targets rather than its own predictions); the per-step logits are
        collected and compared against ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``. Used as the initial decoder
            hidden state.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.

        Returns
        -------
        The value returned by ``self._get_loss(logits, targets, target_mask)``.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self, final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding, decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Decoding always runs for ``self._max_decoding_steps`` steps; at each
        step the highest-probability class is fed back as the next input.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``. Used as the initial decoder
            hidden state.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.

        Returns
        -------
        Tensor of predicted class indices, one column per decoding step, with
        the leading start symbol removed.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        # The prediction history is seeded with the start symbol for every
        # batch element so that ``predictions[-1]`` is a valid first input.
        predictions = [
            final_encoder_output.new_full((batch_size, ),
                                          fill_value=self._start_index,
                                          dtype=torch.long)
        ]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        relevant_targets = targets[:, 1:].contiguous(
        )  # (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous(
        )  # (batch_size, num_decoding_steps)
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets,
                                                  relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        """
        Converts rows of predicted indices into lists of token strings.

        Each row is truncated at the first occurrence of ``self._end_index``
        (the end symbol itself is dropped) and the remaining indices are looked
        up in ``self.vocab`` under ``self._target_namespace``.

        ``predicted_indices`` may be a tensor or a ``numpy.ndarray``; tensors
        are detached and moved to the CPU first.
        """
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(
        self,
        output_dict: Dict[str,
                          torch.Tensor]) -> Dict[str, List[List[str]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``.

        Only element ``[0]`` of each ``"{name}_top_k_predictions"`` entry is
        decoded, and the result is stored wrapped in a single-element list.
        """
        for name in self._states:
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"][
                0]
            output_dict[f"{name}_top_k_predicted_tokens"] = [
                self.decode_all(top_k_predicted_indices)
            ]
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns one metric per entry of ``self._states`` (keyed by the state
        name) when the model is not in training mode, and an empty dict
        otherwise. ``reset`` is forwarded to each ``recall.get_metric`` call.
        """
        all_metrics = {}
        # Recall@10 needs beam search which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics
class WNCGTransformerModel(WNCGBaseModel):
    """Transformer encoder-decoder that generates comment tokens from GPV,
    AMeDAS and meta-data inputs.

    The three input sources are each projected to ``d_model``, concatenated
    along the sequence axis and encoded with an ``nn.TransformerEncoder``.
    Tokens are produced by an ``nn.TransformerDecoder`` attending over that
    memory (optionally extended with weather-label hidden states).

    NOTE(review): the weather-label decoders (``self.sunny_decoder`` etc.) and
    ``self.device`` are not defined in this class; presumably they come from
    ``WNCGBaseModel`` / the Lightning module -- confirm.
    """

    def __init__(self,
                 input_gpv_dim,
                 d_model,
                 nhead,
                 num_layers,
                 n_vocab,
                 input_amedas_seqlen,
                 weather,
                 cl,
                 lr=0.001,
                 dropout_p=0.2,
                 warm_up_steps=4000):
        super().__init__(d_model, weather, cl, dropout_p)

        # ---- Encoder ----
        # encoder for gpv
        self.gpv_encoder = MLPEncoder(input_gpv_dim,
                                      d_model,
                                      dropout_p=dropout_p)
        self.pos_encoder = PositionalEncoding(d_model)
        # encoder for amedas
        self.amedas_to_dmodel = nn.Linear(input_amedas_seqlen, d_model)
        # encoder for meta-data (one embedding table per meta field)
        metaenc = {}
        metaenc["area"] = nn.Embedding(277, d_model)
        metaenc["month"] = nn.Embedding(12, d_model)
        metaenc["day"] = nn.Embedding(31, d_model)
        metaenc["time"] = nn.Embedding(24, d_model)
        metaenc["week"] = nn.Embedding(7, d_model)
        self.meta_encoders = nn.ModuleDict(metaenc)
        # encoder
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers)

        # ---- Decoder ----
        # word decoder
        self.token_embedder = nn.Embedding(n_vocab,
                                           d_model,
                                           padding_idx=IDs.PAD.value)
        self.token_position = PositionalEncoding(d_model)
        self.token_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers)
        self.token_output = nn.Linear(d_model, n_vocab)

        # make the arguments global
        self.lr = lr
        # weather label
        self.weather = weather
        # content agreement loss
        self.cl = cl
        # warm up steps for learning rate
        self.warm_up_steps = warm_up_steps
        # save the arguments
        self.save_hyperparameters()

    def generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` float mask with ``-inf`` strictly above the
        diagonal and ``0`` elsewhere, so position ``i`` cannot attend to any
        position ``j > i``."""
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def encode(self, src_gpv, src_amedas, src_meta, src_comment):
        """Encode the GPV / AMeDAS / meta-data inputs.

        ``src_comment`` is accepted for interface symmetry with ``forward`` but
        is not read here.

        Returns:
            ``(src_gpv, src_amedas, src_meta, src_memory)``: the three
            per-source encodings and the transformer-encoder output computed
            over their concatenation along dim 0.
        """
        # encode gpv-data
        src_gpv = self.gpv_encoder(src_gpv)
        src_gpv = self.pos_encoder(src_gpv)
        # encode amedas-data
        src_amedas = self.amedas_to_dmodel(src_amedas)
        # encode meta-data: row i of ``src_meta`` holds the ids of meta field i
        emb_area = self.meta_encoders["area"](src_meta[0, :])
        emb_month = self.meta_encoders["month"](src_meta[1, :])
        emb_day = self.meta_encoders["day"](src_meta[2, :])
        emb_time = self.meta_encoders["time"](src_meta[3, :])
        emb_week = self.meta_encoders["week"](src_meta[4, :])
        src_meta = torch.stack(
            [emb_area, emb_month, emb_day, emb_time, emb_week], dim=0)
        # concatenate input-data (gpv, amedas, meta)
        # (seq_len[9(gpv) + 4(amedas) + 5(meta)], batch_size, d_model)
        src_data = torch.cat([src_gpv, src_amedas, src_meta], dim=0)
        # encode input-data by transformer
        src_memory = self.transformer_encoder(
            src_data)  # (seq_len, batch_size, d_model)
        return src_gpv, src_amedas, src_meta, src_memory

    def forward(self, src_gpv, src_amedas, src_meta, src_comment):
        """Encode the inputs, optionally decode weather labels, and decode the
        comment tokens with a causal mask over ``src_comment``.

        Args:
            src_gpv: GPV input passed to ``self.gpv_encoder``.
            src_amedas: AMeDAS input passed to ``self.amedas_to_dmodel``.
            src_meta: integer meta-data ids, indexed as ``src_meta[i, :]``
                for the five meta fields.
            src_comment: token ids, sequence-first ``(seq_len, batch_size)``.

        Returns:
            A 7-tuple: log-softmax of the token logits, log-softmax of the
            sunny / cloudy / rain / snow outputs, ``weather_hidden`` (``None``
            unless ``self.weather == "label"``), and the token embeddings of
            ``src_comment`` (before positional encoding).
        """
        # encode gpv/amedas/meta
        src_gpv, src_amedas, src_meta, src_memory = \
            self.encode(src_gpv, src_amedas, src_meta, src_comment)

        # initialize outputs of weather labels and weather hidden.
        # Exactly five values for five targets: the previous six-value tuple
        # raised "too many values to unpack" on every forward pass.
        ZERO = torch.zeros(1, 1).to(self.device)
        sunny_out, cloudy_out, rain_out, snow_out, weather_hidden = \
            ZERO, ZERO, ZERO, ZERO, None

        # decode weather labels
        if self.weather == "label":
            sunny_out, sunny_hidden = self.sunny_decoder(src_memory[0])
            cloudy_out, cloudy_hidden = self.cloudy_decoder(src_memory[0])
            rain_out, rain_hidden = self.rain_decoder(src_memory[0])
            snow_out, snow_hidden = self.snow_decoder(src_memory[0])
            weather_hidden = torch.stack(
                [sunny_hidden, cloudy_hidden, rain_hidden, snow_hidden],
                dim=0)

        # induce weather_hidden into input
        if self.weather is not None:
            src_memory = torch.cat([src_memory, weather_hidden], dim=0)

        # ---- decode tokens ----
        # prepare masks for word decoder
        src_comment_len = src_comment.size(0)  # seq_len
        # mask for padding token
        src_comment_padd_mask = (src_comment == IDs.PAD.value).transpose(
            0, 1).to(self.device)  # (batch_size, seq_len)
        # mask for subsequence
        src_comment_attn_mask = self.generate_square_subsequent_mask(
            src_comment_len).to(self.device)  # (seq_len, seq_len)
        # embedding
        src_comment_emb = self.token_embedder(
            src_comment)  # (seqlen, batch_size, d_model)
        src_comment_emb_pos = self.token_position(src_comment_emb)
        # decode
        token_hidden = self.token_decoder(
            src_comment_emb_pos,
            src_memory,
            tgt_mask=src_comment_attn_mask,
            tgt_key_padding_mask=src_comment_padd_mask
        )  # (seqlen, batch_size, d_model)
        # output distribution over vocabularies
        token_out = self.token_output(token_hidden)

        return (F.log_softmax(token_out, dim=-1),
                F.log_softmax(sunny_out, dim=-1),
                F.log_softmax(cloudy_out, dim=-1),
                F.log_softmax(rain_out, dim=-1),
                F.log_softmax(snow_out, dim=-1),
                weather_hidden,
                src_comment_emb)

    # learning rate warm-up
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp,
                       using_lbfgs):
        """Set the learning rate from the warm-up schedule, then step.

        The rate is ``d_model**-0.5 * min(step**-0.5, step * warm_up**-1.5)``
        with ``step = trainer.global_step + 1``. It is written to ``self.lr``
        and to every param group before ``optimizer.step`` is called with the
        supplied closure; gradients are zeroed afterwards. The remaining
        arguments are part of the hook signature and are not read here.
        """
        # warm up lr (identical for every param group, so compute it once)
        step = float(self.trainer.global_step + 1)
        self.lr = (self.hparams.d_model**-0.5) * min(
            step**-0.5, step * self.warm_up_steps**-1.5)
        for pg in optimizer.param_groups:
            pg['lr'] = self.lr
        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, starting at ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-09)
        return optimizer

    @torch.no_grad()
    def greedy_token_decode(self,
                            src_memory,
                            gpv_output,
                            amedas_output,
                            meta_output,
                            weather_hidden,
                            token_generation_limit=128):
        """Greedy (argmax) token decoding from an encoded memory.

        Decoding always runs for ``token_generation_limit`` steps; there is no
        early stop on EOS. ``gpv_output``, ``amedas_output`` and
        ``meta_output`` are accepted for interface compatibility and are not
        read. Gradients are disabled because the result is returned as a plain
        Python list and was already detached.

        Returns:
            ``batch_size`` lists of ``token_generation_limit`` token ids
            (stored as floats, since they are written into a float tensor).
        """
        batch_size = src_memory.size(1)
        # (1, batch_size): every sequence starts with BOS
        src_comment = torch.tensor(
            [[IDs.BOS.value for _ in range(batch_size)]],
            dtype=torch.long).to(self.device)
        decoded_batch = torch.zeros((batch_size, token_generation_limit))

        # induce weather_hidden into input
        if self.weather is not None:
            src_memory = torch.cat([src_memory, weather_hidden], dim=0)

        for idx in range(token_generation_limit):
            # prepare masks for word decoder
            src_comment_len = idx + 1  # seq_len
            # mask for padding token
            src_comment_padd_mask = (src_comment == IDs.PAD.value).transpose(
                0, 1).to(self.device)  # (batch_size, seq_len)
            # mask for subsequence
            src_comment_attn_mask = self.generate_square_subsequent_mask(
                src_comment_len).to(self.device)  # (seq_len, seq_len)
            # embedding
            src_comment_emb = self.token_embedder(
                src_comment)  # (seqlen, batch_size, d_model)
            src_comment_emb = self.token_position(src_comment_emb)
            # decode
            token_hidden = self.token_decoder(
                src_comment_emb,
                src_memory,
                tgt_mask=src_comment_attn_mask,
                tgt_key_padding_mask=src_comment_padd_mask
            )  # (seqlen, batch_size, d_model)
            # output distribution over vocabularies
            token_out = self.token_output(token_hidden)
            # best token for the last position of every batch element
            _, topi = token_out[-1, :, :].data.topk(1)  # (batch_size, 1)
            decoded_batch[:, idx] = topi.view(-1)
            topi = topi.transpose(0, 1)  # (1, batch_size)
            # concat source with output
            src_comment = torch.cat([src_comment, topi], dim=0)

        return decoded_batch.detach().tolist()

    @torch.no_grad()
    def beam_token_decode(self,
                          src_memory,
                          gpv_output,
                          amedas_output,
                          meta_output,
                          weather_hidden,
                          beam_width=5,
                          max_steps=128):
        """Beam-search token decoding from an encoded memory.

        Args:
            src_memory: encoder output, ``(seq_len, batch_size, d_model)``.
            gpv_output, amedas_output, meta_output: accepted for interface
                compatibility; not read.
            weather_hidden: concatenated onto ``src_memory`` along dim 0 when
                ``self.weather`` is not ``None``.
            beam_width: beam size handed to ``BeamSearch``.
            max_steps: maximum number of decoding steps (defaults to the
                previously hard-coded 128).

        Returns:
            ``(predictions, log_probs)`` exactly as returned by
            ``BeamSearch.search``.
        """
        self.beam_search = BeamSearch(end_index=IDs.EOS.value,
                                      max_steps=max_steps,
                                      beam_size=beam_width)
        batch_size = src_memory.size(1)

        # induce weather_hidden into input
        if self.weather is not None:
            src_memory = torch.cat([src_memory, weather_hidden], dim=0)

        start_predictions = torch.tensor([IDs.BOS.value] * batch_size,
                                         dtype=torch.long,
                                         device=self.device)
        start_state = {
            # no previous tokens yet
            "prev_tokens":
            torch.zeros(batch_size, 0, dtype=torch.long, device=self.device),
            # BeamSearch expands / re-indexes every state tensor along dim 0
            # and therefore requires the batch (group) axis to come first.
            # The memory is stored batch-first here and transposed back to the
            # sequence-first layout the decoder expects inside ``step``.
            # (batch_size, seq_len, d_model)
            "decoder_hidden":
            src_memory.transpose(0, 1)
        }

        def step(last_tokens, current_state, t):
            """Single beam-search step.

            Args:
                last_tokens: (group_size,)
                current_state: state dict with batch-first tensors
                t: int, current timestep (unused)
            """
            # concatenate prev_tokens with last_tokens
            prev_tokens = torch.cat(
                [current_state["prev_tokens"],
                 last_tokens.unsqueeze(1)],
                dim=-1)  # (group_size, t+1)
            # embedding
            prev_tokens_emb = self.token_embedder(prev_tokens).transpose(
                0, 1)  # (seqlen, group_size, d_model)
            prev_tokens_emb = self.token_position(prev_tokens_emb)
            prev_tokens_len = prev_tokens.size(1)
            # mask for padding token
            prev_token_padd_mask = (prev_tokens == IDs.PAD.value).to(
                self.device)  # (group_size, seq_len)
            # mask for subsequence
            prev_token_attn_mask = self.generate_square_subsequent_mask(
                prev_tokens_len).to(self.device)  # (seq_len, seq_len)
            # back to sequence-first for nn.TransformerDecoder
            memory = current_state["decoder_hidden"].transpose(
                0, 1)  # (src_len, group_size, d_model)
            # decode
            token_hidden = self.token_decoder(
                prev_tokens_emb,
                memory,
                tgt_mask=prev_token_attn_mask,
                tgt_key_padding_mask=prev_token_padd_mask
            )  # (seqlen, group_size, d_model)
            # output distribution over vocabularies
            token_out = self.token_output(token_hidden)
            # get output distribution for last token
            decoder_output = F.log_softmax(token_out[-1, :, :], dim=-1)
            # update prev_tokens
            current_state["prev_tokens"] = prev_tokens
            return (decoder_output, current_state)

        predictions, log_probs = self.beam_search.search(
            start_predictions=start_predictions,
            start_state=start_state,
            step=step)
        return predictions, log_probs
class MSPointerNetwork(Model):
    """Two-source pointer network.

    Two source sequences are encoded separately. At every decoding step an
    LSTM cell attends over both encodings; the two attention distributions are
    used as copy distributions over the source tokens and are mixed by a
    scalar sigmoid gate. Output tokens are therefore always copied from one of
    the two sources.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder_1: TextFieldEmbedder,
                 source_encoder_1: Seq2SeqEncoder,
                 beam_size: int,
                 max_decoding_steps: int,
                 decoder_output_dim: int,
                 target_embedding_dim: int = 30,
                 namespace: str = "tokens",
                 tensor_based_metric: Metric = None,
                 align_embeddings: bool = True,
                 source_embedder_2: TextFieldEmbedder = None,
                 source_encoder_2: Seq2SeqEncoder = None) -> None:
        super().__init__(vocab)
        self._source_embedder_1 = source_embedder_1
        # Fall back to the first embedder only when no second one is given.
        # (Previously ``source_embedder_1 or ...`` was used here, which
        # silently ignored the ``source_embedder_2`` argument.)
        self._source_embedder_2 = source_embedder_2 or self._source_embedder_1
        self._source_encoder_1 = source_encoder_1
        self._source_encoder_2 = source_encoder_2 or self._source_encoder_1

        self._source_namespace = namespace
        self._target_namespace = namespace

        self.encoder_output_dim_1 = self._source_encoder_1.get_output_dim()
        self.encoder_output_dim_2 = self._source_encoder_2.get_output_dim()
        self.cated_encoder_out_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2
        self.decoder_output_dim = decoder_output_dim

        # TODO: AllenNLP's AdditiveAttention implementation may not have a bias
        self._attention_1 = AdditiveAttention(self.decoder_output_dim,
                                              self.encoder_output_dim_1)
        self._attention_2 = AdditiveAttention(self.decoder_output_dim,
                                              self.encoder_output_dim_2)

        if not align_embeddings:
            # Separate, freshly initialised target embedding table.
            self.target_embedding_dim = target_embedding_dim
            self._target_vocab_size = self.vocab.get_vocab_size(
                namespace=self._target_namespace)
            self._target_embedder = Embedding(self._target_vocab_size,
                                              target_embedding_dim)
        else:
            # Share the "tokens" embedder of the first source.
            self._target_embedder = self._source_embedder_1._token_embedders[
                "tokens"]
            self._target_vocab_size = self.vocab.get_vocab_size(
                namespace=self._target_namespace)
            self.target_embedding_dim = self._target_embedder.get_output_dim()

        self.decoder_input_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2 + \
                                 self.target_embedding_dim

        self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                      self.decoder_output_dim)

        # Maps the concatenated final hidden states of the two encoders to the
        # initial decoder state.
        self._encoder_out_projection_layer = torch.nn.Linear(
            in_features=self.cated_encoder_out_dim,
            out_features=self.decoder_output_dim
        )  # TODO: bias - true or false?

        # Soft gating parameters, used to compute lambda.
        self._gate_projection_layer = torch.nn.Linear(
            in_features=self.decoder_output_dim + self.decoder_input_dim,
            out_features=1,
            bias=False)

        self._start_index = self.vocab.get_token_index(START_SYMBOL, namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        self._tensor_based_metric = tensor_based_metric or \
            BLEU(exclude_indices={self._pad_index, self._end_index, self._start_index})

    def _encode(
            self, source_tokens_1: Dict[str, torch.Tensor],
            source_tokens_2: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode the token ids of source 1 and source 2 with their respective
        encoders and return, for each, the mask and the encoder output. The
        raw token ids are attached as well.
        """
        # 1. encode source 1
        # shape: (batch_size, seq_max_len_1)
        source_mask_1 = util.get_text_field_mask(source_tokens_1)
        # shape: (batch_size, seq_max_len_1, encoder_input_dim_1)
        embedder_out_1 = self._source_embedder_1(source_tokens_1)
        # shape: (batch_size, seq_max_len_1, encoder_output_dim_1)
        encoder_out_1 = self._source_encoder_1(embedder_out_1, source_mask_1)

        # 2. encode source 2
        # shape: (batch_size, seq_max_len_2)
        source_mask_2 = util.get_text_field_mask(source_tokens_2)
        # shape: (batch_size, seq_max_len_2, encoder_input_dim_2)
        embedder_out_2 = self._source_embedder_2(source_tokens_2)
        # shape: (batch_size, seq_max_len_2, encoder_output_dim_2)
        encoder_out_2 = self._source_encoder_2(embedder_out_2, source_mask_2)

        return {
            "source_mask_1": source_mask_1,
            "source_mask_2": source_mask_2,
            "source_token_ids_1": source_tokens_1["tokens"],
            "source_token_ids_2": source_tokens_2["tokens"],
            "encoder_out_1": encoder_out_1,
            "encoder_out_2": encoder_out_2,
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Initialise the decoder: update ``state`` in place so that it carries
        the decoder hidden and context vectors.

        The hidden vector (h_0) is a ReLU-activated linear projection of the
        concatenated final hidden states of both encoders; the context is
        initialised to zeros.
        """
        batch_size = state["encoder_out_1"].size()[0]

        # Pick the final RNN hidden state of each sequence according to its mask.
        # shape: (batch_size, encoder_output_dim_1)
        encoder_final_output_1 = util.get_final_encoder_states(
            state["encoder_out_1"], state["source_mask_1"],
            self._source_encoder_1.is_bidirectional())
        # shape: (batch_size, encoder_output_dim_2)
        encoder_final_output_2 = util.get_final_encoder_states(
            state["encoder_out_2"], state["source_mask_2"],
            self._source_encoder_2.is_bidirectional())

        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = torch.relu(
            self._encoder_out_projection_layer(
                torch.cat([encoder_final_output_1, encoder_final_output_2],
                          dim=-1)))
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["decoder_hidden"].new_zeros(
            batch_size, self.decoder_output_dim)

        return state

    @overrides
    def forward(
        self,
        source_tokens_1: Dict[str, torch.LongTensor],
        source_tokens_2: Dict[str, torch.LongTensor],
        metadata: List[Dict[str, Any]],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Run the model in one of three modes.

        1. Training (``self.training``): ``target_tokens`` must be given; only
           the loss is computed. Output keys: ``loss``, ``metadata``.
        2. Validation / test (not training, ``target_tokens`` given): the loss
           is computed, beam search is run and the tensor-based metric is
           updated with the best beam. Output keys: ``loss``, ``metadata``,
           ``predictions``, ``predicted_log_probs``.
        3. Prediction (not training, no ``target_tokens``): only beam search
           is run. Output keys: ``metadata``, ``predictions``,
           ``predicted_log_probs``.
        """
        # 1. Training: target_tokens is necessarily provided as ground truth.
        #    Only the loss is needed, no beam search.
        if self.training:
            assert target_tokens is not None

            state = self._encode(source_tokens_1, source_tokens_2)
            state["target_token_ids"] = target_tokens["tokens"]
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, state)
            output_dict["metadata"] = metadata
            return output_dict  # contains loss and metadata

        # 2. Validation / test: self.training is False but target_tokens is
        #    provided. Compute the loss, run beam search, update the metric.
        elif target_tokens:
            # compute the loss
            state = self._encode(source_tokens_1, source_tokens_2)
            state["target_token_ids"] = target_tokens["tokens"]
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, state)

            # run beam search; the decoder state is re-initialised because the
            # loss computation above advanced it through the whole target.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            # update the evaluation metric (BLEU by default)
            if self._tensor_based_metric is not None:
                # shape: (batch_size, beam_size, max_decoding_steps)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_decoding_steps)
                best_predictions = top_k_predictions[:, 0, :]
                # shape: (batch_size, target_seq_len)
                gold_tokens = target_tokens["tokens"]
                self._tensor_based_metric(best_predictions, gold_tokens)
            output_dict["metadata"] = metadata
            return output_dict  # contains loss, metadata, top-k, top-k log prob

        # 3. Prediction: self.training is False and no target_tokens given.
        #    Only beam search is needed for the top-k predictions.
        else:
            state = self._encode(source_tokens_1, source_tokens_2)
            state = self._init_decoder_state(state)
            output_dict = {"metadata": metadata}
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            return output_dict  # contains metadata, top-k, top-k log prob

    def _forward_loss(
            self, target_tokens: Dict[str, torch.Tensor],
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Compute the loss for one batch: the negative log-likelihood of the
        gold target sequence, summed over (unmasked) timesteps and averaged
        over the batch. ``state`` must already contain ``target_token_ids``
        and an initialised decoder state.
        """
        batch_size, target_seq_len = target_tokens["tokens"].size()

        # shape: (batch_size, seq_max_len_1)
        source_mask_1 = state["source_mask_1"]
        # shape: (batch_size, seq_max_len_2)
        source_mask_2 = state["source_mask_2"]

        # The number of steps to generate is always one less than the maximum
        # length of the target sequence (<start> ... <end>).
        num_decoding_steps = target_seq_len - 1
        # log-likelihood of the target token at every timestep
        step_log_likelihoods = []
        for timestep in range(num_decoding_steps):
            # token id fed in at this timestep, shape: (batch_size,)
            input_choices = target_tokens["tokens"][:, timestep]
            # Advance the decoder one step (computes the intermediate values,
            # e.g. attention scores and the soft gate score).
            state = self._decoder_step(input_choices, state)

            # attention of decoder_hidden over the two sources
            # shape: (batch_size, seq_max_len_1)
            attentive_weights_1 = state["attentive_weights_1"]
            # shape: (batch_size, seq_max_len_2)
            attentive_weights_2 = state["attentive_weights_2"]

            # target_to_source marks the source positions that hold the token
            # to be produced now (the ground truth).
            # shape: (batch_size, seq_max_len_1)
            target_to_source_1 = (state["source_token_ids_1"] ==
                                  state["target_token_ids"][:, timestep +
                                                            1].unsqueeze(-1))
            # shape: (batch_size, seq_max_len_2)
            target_to_source_2 = (state["source_token_ids_2"] ==
                                  state["target_token_ids"][:, timestep +
                                                            1].unsqueeze(-1))

            # log-likelihood of the target token at the current timestep
            step_log_likelihood = self._get_ll_contrib(
                attentive_weights_1, attentive_weights_2, source_mask_1,
                source_mask_2, target_to_source_1, target_to_source_2,
                state["target_token_ids"][:, timestep + 1],
                state["gate_score"])
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        # merge the per-timestep log-likelihoods into one tensor
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        log_likelihoods = torch.cat(step_log_likelihoods, 1)
        # target mask including START and END
        # shape: (batch_size, target_seq_len)
        target_mask = util.get_text_field_mask(target_tokens)
        # drop the first position: START is never a target token
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        target_mask = target_mask[:, 1:].float()
        # masked sum over timesteps gives the sequence log-likelihood
        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)
        loss = -log_likelihood.sum() / batch_size

        return {"loss": loss}

    def _decoder_step(
            self, last_predictions: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Advance the decoder by one step.

        Writes ``attentive_weights_1``, ``attentive_weights_2``,
        ``gate_score``, ``decoder_hidden`` and ``decoder_context`` into
        ``state`` and returns it.
        """
        # shape: (group_size, seq_max_len_1)
        source_mask_1 = state["source_mask_1"].float()
        # shape: (group_size, seq_max_len_2)
        source_mask_2 = state["source_mask_2"].float()
        # y_{t-1}, shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # a_t, shape: (group_size, seq_max_len_1)
        state["attentive_weights_1"] = self._attention_1(
            state["decoder_hidden"], state["encoder_out_1"], source_mask_1)
        # a'_t, shape: (group_size, seq_max_len_2)
        state["attentive_weights_2"] = self._attention_2(
            state["decoder_hidden"], state["encoder_out_2"], source_mask_2)

        # c_t, shape: (group_size, encoder_output_dim_1)
        attentive_read_1 = util.weighted_sum(state["encoder_out_1"],
                                             state["attentive_weights_1"])
        # c'_t, shape: (group_size, encoder_output_dim_2)
        attentive_read_2 = util.weighted_sum(state["encoder_out_2"],
                                             state["attentive_weights_2"])

        # Soft gate (lambda), computed from the decoder hidden state *before*
        # the cell update below.
        # shape: (group_size, target_embedding_dim + encoder_output_dim_1 +
        #         encoder_output_dim_2 + decoder_output_dim)
        gate_input = torch.cat((embedded_input, attentive_read_1,
                                attentive_read_2, state["decoder_hidden"]),
                               dim=-1)
        # shape: (group_size,)
        gate_projected = self._gate_projection_layer(gate_input).squeeze(-1)
        # shape: (group_size,)
        state["gate_score"] = torch.sigmoid(gate_projected)

        # shape: (group_size, target_embedding_dim + encoder_output_dim_1 +
        #         encoder_output_dim_2)
        decoder_input = torch.cat(
            (embedded_input, attentive_read_1, attentive_read_2), dim=-1)

        # update the decoder state (hidden and context / cell)
        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
            decoder_input, (state["decoder_hidden"], state["decoder_context"]))

        return state

    def _get_ll_contrib(self, copy_scores_1: torch.Tensor,
                        copy_scores_2: torch.Tensor,
                        source_mask_1: torch.Tensor,
                        source_mask_2: torch.Tensor,
                        target_to_source_1: torch.Tensor,
                        target_to_source_2: torch.Tensor,
                        target_tokens: torch.Tensor,
                        gate_score: torch.Tensor) -> torch.Tensor:
        """
        Compute the log-likelihood of the gold token at one timestep from the
        attention scores.

        Parameters:
            - copy_scores_1: attention scores over the first source.
                             shape: (batch_size, seq_max_len_1)
            - copy_scores_2: attention scores over the second source.
                             shape: (batch_size, seq_max_len_2)
            - source_mask_1: mask of the first source.
                             shape: (batch_size, seq_max_len_1)
            - source_mask_2: mask of the second source.
                             shape: (batch_size, seq_max_len_2)
            - target_to_source_1: whether the target token equals the token at
                                  each position of the first source.
                             shape: (batch_size, seq_max_len_1)
            - target_to_source_2: whether the target token equals the token at
                                  each position of the second source.
                             shape: (batch_size, seq_max_len_2)
            - target_tokens: target token of the current timestep (not read by
                             this method; kept for interface compatibility).
                             shape: (batch_size,)
            - gate_score: probability (between 0 and 1) of copying from the
                          first source.
                             shape: (batch_size,)

        Returns:
            Log-likelihood of generating the target token at this timestep.
            shape: (batch_size,)
        """
        # Score of the first source. The tiny epsilon keeps log() finite for
        # zero entries (non-matching or masked positions).
        # shape: (batch_size, seq_max_len_1)
        combined_log_probs_1 = (copy_scores_1 + 1e-45).log() + (
            target_to_source_1.float() +
            1e-45).log() + (source_mask_1.float() + 1e-45).log()
        # shape: (batch_size,)
        log_probs_1 = util.logsumexp(
            combined_log_probs_1)  # log(exp(a[0]) + ... + exp(a[L]))

        # Score of the second source.
        # shape: (batch_size, seq_max_len_2)
        combined_log_probs_2 = (copy_scores_2 + 1e-45).log() + (
            target_to_source_2.float() +
            1e-45).log() + (source_mask_2.float() + 1e-45).log()
        # shape: (batch_size,)
        log_probs_2 = util.logsumexp(
            combined_log_probs_2)  # log(exp(a[0]) + ... + exp(a[L]))

        # log(p1 * gate + p2 * (1 - gate))
        log_gate_score_1 = gate_score.log()  # shape: (batch_size,)
        log_gate_score_2 = (1 - gate_score).log()  # shape: (batch_size,)
        item_1 = (log_gate_score_1 + log_probs_1).unsqueeze(
            -1)  # shape: (batch_size, 1)
        item_2 = (log_gate_score_2 + log_probs_2).unsqueeze(
            -1)  # shape: (batch_size, 1)
        step_log_likelihood = util.logsumexp(torch.cat(
            (item_1, item_2), -1))  # shape: (batch_size,)
        return step_log_likelihood

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run beam search from an initialised decoder ``state``.

        Returns a dict with ``predicted_log_probs`` and ``predictions`` as
        produced by ``BeamSearch.search``.
        """
        batch_size = state["source_mask_1"].size()[0]
        # ``new_full`` inherits the dtype of the mask it is called on, which
        # may be bool/byte; token indices must be int64, so request it.
        start_predictions = state["source_mask_1"].new_full(
            (batch_size, ), fill_value=self._start_index, dtype=torch.long)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_search_step)
        return {
            "predicted_log_probs": log_probabilities,
            "predictions": all_top_k_predictions
        }

    def take_search_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Step function handed to beam search.

        Parameters:
            - last_predictions: predictions of the previous timestep
                                (initially the start index).
                                shape: (group_size,)
            - state: decoder state

        Returns:
            - final_log_probs: log-likelihood over the whole vocabulary.
                               shape: (group_size, target_vocab_size)
            - state: the updated state

        TODO: handle OOV tokens (requires a larger redesign).
        """
        # advance the decoder one step
        state = self._decoder_step(last_predictions, state)

        # copy probabilities over the first source,
        # shape: (group_size, seq_max_len_1)
        copy_scores_1 = state["attentive_weights_1"]
        # copy probabilities over the second source,
        # shape: (group_size, seq_max_len_2)
        copy_scores_2 = state["attentive_weights_2"]
        # gate between the two, shape: (group_size,)
        gate_score = state["gate_score"]

        # log-likelihood over the whole vocabulary
        final_log_probs = self._gather_final_log_probs(copy_scores_1,
                                                       copy_scores_2,
                                                       gate_score, state)

        return final_log_probs, state

    def _gather_final_log_probs(
            self, copy_scores_1: torch.Tensor, copy_scores_2: torch.Tensor,
            gate_score: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combine the two copy distributions and the gate into log-probabilities
        over the whole vocabulary.

        Parameters:
            - copy_scores_1: (normalised) copy probabilities of the first source
                             shape: (group_size, seq_max_len_1)
            - copy_scores_2: (normalised) copy probabilities of the second source
                             shape: (group_size, seq_max_len_2)
            - gate_score: share contributed by source 1 (source 2 contributes
                          ``1 - gate_score``)
                             shape: (group_size,)
            - state: decoder state after the current step

        Returns:
            - final_log_probs: log-probabilities over the whole vocabulary
                             shape: (group_size, target_vocab_size)
        """
        # group size and the two sequence lengths
        group_size, seq_max_len_1 = copy_scores_1.size()
        group_size, seq_max_len_2 = copy_scores_2.size()

        # TODO: this assumes source and target share one vocabulary mapping;
        # otherwise a source-to-target index mapping is needed for matching.
        # shape: (group_size, seq_max_len_1)
        source_token_ids_1 = state["source_token_ids_1"]
        # shape: (group_size, seq_max_len_2)
        source_token_ids_2 = state["source_token_ids_2"]

        # Weight each copy distribution by its gate share; unsqueezing lets
        # the (group_size,) gate broadcast across the sequence axis.
        # shape: (group_size, seq_max_len_1)
        copy_scores_1 = copy_scores_1 * gate_score.unsqueeze(-1)
        # shape: (group_size, seq_max_len_2)
        copy_scores_2 = copy_scores_2 * (1 - gate_score).unsqueeze(-1)

        # shape: (group_size, seq_max_len_1)
        log_probs_1 = (copy_scores_1 + 1e-45).log()
        # shape: (group_size, seq_max_len_2)
        log_probs_2 = (copy_scores_2 + 1e-45).log()

        # Start from (near-)zero probability for every vocabulary entry,
        # shape: (group_size, target_vocab_size)
        final_log_probs = (state["decoder_hidden"].new_zeros(
            (group_size, self._target_vocab_size)) + 1e-45).log()

        # Positions are processed one at a time so that a token occurring at
        # several source positions accumulates all of its probability mass.
        for i in range(seq_max_len_1):
            # probability of this position, shape: (group_size, 1)
            log_probs_slice = log_probs_1[:, i].unsqueeze(-1)
            # token ids of this position, shape: (group_size, 1)
            source_to_target_slice = source_token_ids_1[:, i].unsqueeze(-1)

            # current vocabulary log-prob at those ids, shape: (group_size, 1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            # existing + new probability, shape: (group_size, 1)
            combined_scores = util.logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            # write the combined score back
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        # same procedure for source 2
        for i in range(seq_max_len_2):
            log_probs_slice = log_probs_2[:, i].unsqueeze(-1)
            source_to_target_slice = source_token_ids_2[:, i].unsqueeze(-1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            combined_scores = util.logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        return final_log_probs

    def _get_predicted_tokens(
            self,
            predicted_indices: Union[torch.Tensor, numpy.ndarray],
            batch_metadata: List[Any],
            n_best: int = None) -> List[Union[List[List[str]], List[str]]]:
        """
        Convert predicted indices to token strings.

        For every batch element the first ``n_best`` beams (all beams when
        ``n_best`` is ``None``) are truncated at the first end symbol and
        looked up in the vocabulary. With ``n_best == 1`` each batch element
        yields a single token list instead of a list of token lists.
        """
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        predicted_tokens: List[Union[List[List[str]], List[str]]] = []
        # ``zip`` limits the iteration to the shorter of the two sequences;
        # the metadata entry itself is not read.
        for top_k_predictions, metadata in zip(predicted_indices,
                                               batch_metadata):
            batch_predicted_tokens: List[List[str]] = []
            for indices in top_k_predictions[:n_best]:
                tokens: List[str] = []
                indices = list(indices)
                if self._end_index in indices:
                    indices = indices[:indices.index(self._end_index)]
                for index in indices:
                    token = self.vocab.get_token_from_index(
                        index, self._target_namespace)
                    tokens.append(token)
                batch_predicted_tokens.append(tokens)
            if n_best == 1:
                predicted_tokens.append(batch_predicted_tokens[0])
            else:
                predicted_tokens.append(batch_predicted_tokens)
        return predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Decode the predictions (tensor) into token sequences and store them
        under ``predicted_tokens``.
        """
        predicted_tokens = self._get_predicted_tokens(
            output_dict["predictions"], output_dict["metadata"])
        output_dict["predicted_tokens"] = predicted_tokens
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the tensor-based metric values when not training (the metric is
        only updated alongside beam search); an empty dict otherwise.
        """
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset))
        return all_metrics
class WNCGRnnModel(WNCGBaseModel):
    """RNN-based comment generator on top of `WNCGBaseModel`.

    GPV inputs go through an MLP and a single-layer bidirectional GRU, AMeDAS inputs
    through a linear layer, and five metadata ids (area, month, day, time, week) through
    embeddings. Their concatenation initialises the first layer of the hidden state used
    by an attention token decoder.
    """

    def __init__(self,
                 input_gpv_dim,
                 d_model,
                 num_layers,
                 n_vocab,
                 input_amedas_seqlen,
                 weather,
                 cl,
                 lr=0.001,
                 dropout_p=0.2):
        """Build the encoder and decoder sub-modules and record the hyper-parameters."""
        super().__init__(d_model, weather, cl, dropout_p)

        # Width shared by the AMeDAS projection and every metadata embedding.
        side_dim = 64

        self.num_layers = num_layers
        self.d_model = d_model

        # ---- Encoder ----
        # GPV: a single-layer MLP per grid, then a single-layer BiGRU over the sequence,
        # then a projection of the two GRU directions back to d_model.
        self.gpv_encoder = MLPEncoder(input_gpv_dim, d_model, dropout_p=dropout_p)
        self.rnn_encoder = nn.GRU(d_model, d_model, num_layers=1, bidirectional=True)
        self.gpv_to_dmodel = nn.Linear(d_model * 2, d_model)

        # AMeDAS: a linear projection.
        self.amedas_to_dmodel = nn.Linear(input_amedas_seqlen, side_dim)

        # Metadata: one embedding table per field, created in this order.
        table_sizes = (("area", 277), ("month", 12), ("day", 31), ("time", 24),
                       ("week", 7))
        self.meta_encoders = nn.ModuleDict(
            {field: nn.Embedding(size, side_dim) for field, size in table_sizes})

        # Fuses the GPV summary (2 * d_model), the AMeDAS block (4 * side_dim) and the
        # metadata block (5 * side_dim) into the decoder's initial state.
        fused_dim = (d_model * 2) + (side_dim * 4) + (side_dim * 5)
        self.input_to_dmodel = nn.Linear(fused_dim, d_model)
        self.relu = nn.ReLU()

        # ---- Decoder ----
        self.token_decoder = TokenAttnDecoderRNN(d_model, side_dim, d_model,
                                                 n_vocab, num_layers, weather,
                                                 dropout_p)

        self.weather = weather  # weather-label mode
        self.cl = cl  # option for the content agreement loss
        self.lr = lr
        self.save_hyperparameters()

    def encode(self, src_gpv, src_amedas, src_meta, src_comment):
        """Encode GPV, AMeDAS and metadata inputs.

        `src_comment` is only used for its second dimension (the batch size).

        Returns:
            The projected GPV sequence, the projected AMeDAS tensor, the stacked
            metadata embeddings and the initial decoder hidden state.
        """
        _, batch_size = src_comment.size()

        # GPV: MLP -> BiGRU -> projection (the GRU's final hidden state is not used).
        gpv_output, _ = self.rnn_encoder(self.gpv_encoder(src_gpv))
        gpv_output = self.gpv_to_dmodel(gpv_output)

        # AMeDAS
        amedas_output = self.amedas_to_dmodel(src_amedas)

        # Metadata: row k of src_meta feeds the k-th embedding table.
        fields = ("area", "month", "day", "time", "week")
        meta_output = torch.stack(
            [self.meta_encoders[name](src_meta[row, :])
             for row, name in enumerate(fields)],
            dim=0)

        # Summarise each source as one vector per batch element.
        gpv_summary = torch.cat([gpv_output[0, :, :], gpv_output[-1, :, :]], dim=1)
        amedas_flat = amedas_output.transpose(0, 1).reshape(batch_size, -1)
        meta_flat = meta_output.transpose(0, 1).reshape(batch_size, -1)

        fused = torch.cat([gpv_summary, amedas_flat, meta_flat], dim=1)
        first_layer_state = self.relu(self.input_to_dmodel(fused))  # (batch_size, d_model)

        return gpv_output, amedas_output, meta_output, self.reset(first_layer_state)

    def reset(self, hidden_state):
        """Create the decoder hidden state: layer 0 is `hidden_state`, the remaining
        layers are drawn from N(0, 0.05)."""
        shape = (self.num_layers, hidden_state.size(0), self.d_model)
        state = torch.zeros(shape, dtype=torch.float32).to(self.device)
        nn.init.normal_(state, mean=0, std=0.05)
        state[0, :, :] = hidden_state
        return state

    def forward(self, src_gpv, src_amedas, src_meta, src_comment):
        """Teacher-forced pass over `src_comment`.

        Returns:
            A 7-tuple: token log-probabilities, the sunny / cloudy / rain / snow
            log-probabilities (computed from a `(1, 1)` zero placeholder unless
            `self.weather == "label"`), the stacked weather hidden states (or `None`)
            and the concatenated target word embeddings.
        """
        gpv_output, amedas_output, meta_output, encoder_hidden = self.encode(
            src_gpv, src_amedas, src_meta, src_comment)

        # Placeholders used when weather labels are not decoded.
        placeholder = torch.zeros(1, 1).to(self.device)
        label_logits = [placeholder, placeholder, placeholder, placeholder]
        weather_hidden = None

        if self.weather == "label":
            label_logits, label_states = [], []
            for label in ("sunny", "cloudy", "rain", "snow"):
                logits, state = getattr(self, f"{label}_decoder")(encoder_hidden[0])
                label_logits.append(logits)
                label_states.append(state)
            weather_hidden = torch.stack(label_states, dim=0)

        # Decode the comment one gold token at a time.
        step_outputs, step_embeddings = [], []
        hidden = encoder_hidden
        for word_input in src_comment:
            step_output, hidden, step_embedding = self.token_decoder(
                word_input, hidden, gpv_output, amedas_output, meta_output,
                weather_hidden)
            step_outputs.append(step_output)
            step_embeddings.append(step_embedding)

        token_out = torch.stack(step_outputs, dim=0)
        tgt_text_embed = torch.cat(step_embeddings, dim=0)

        sunny_out, cloudy_out, rain_out, snow_out = label_logits
        return (
            F.log_softmax(token_out, dim=-1),
            F.log_softmax(sunny_out, dim=-1),
            F.log_softmax(cloudy_out, dim=-1),
            F.log_softmax(rain_out, dim=-1),
            F.log_softmax(snow_out, dim=-1),
            weather_hidden,
            tgt_text_embed,
        )

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        """Step the optimizer with the supplied closure, then clear the gradients.

        Only `optimizer` and `optimizer_closure` are used; the other arguments are
        accepted and ignored.
        """
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def configure_optimizers(self):
        """Return Adam plus a `ReduceLROnPlateau` schedule monitored on `val_loss`."""
        adam = torch.optim.Adam(self.parameters(), lr=self.lr)
        plateau = ReduceLROnPlateau(adam,
                                    'min',
                                    factor=0.5,
                                    patience=1,
                                    min_lr=1e-5,
                                    verbose=True)
        schedule_config = {
            'scheduler': plateau,        # the LR scheduler instance (required)
            'interval': 'epoch',         # unit of the scheduler's step size
            'frequency': 1,              # how often the scheduler steps
            'reduce_on_plateau': True,   # marks this as a ReduceLROnPlateau scheduler
            'monitor': 'val_loss',       # metric the scheduler watches
            'strict': True,              # fail if `monitor` is not found
            'name': None,                # custom name for LearningRateMonitor
        }
        return [adam], [schedule_config]

    def greedy_token_decode(self,
                            hidden,
                            gpv_output,
                            amedas_output,
                            meta_output,
                            weather_hidden,
                            token_generation_limit=128):
        """Greedy decoding for exactly `token_generation_limit` steps.

        Returns:
            A nested list of shape `batch_size x token_generation_limit` holding the
            chosen token ids (stored in a float tensor, so the values are floats).
        """
        _, batch_size, _ = hidden.size()
        picked = torch.zeros((batch_size, token_generation_limit))
        word_input = torch.tensor([[IDs.BOS.value]] * batch_size,
                                  dtype=torch.long).to(self.device)

        for position in range(token_generation_limit):
            output, hidden, _ = self.token_decoder(word_input, hidden, gpv_output,
                                                   amedas_output, meta_output,
                                                   weather_hidden)
            # Highest-scoring token per batch element; it is also the next input.
            _, best = output.data.topk(1)
            picked[:, position] = best.view(-1)
            word_input = best

        return picked.detach().tolist()

    @torch.no_grad()
    def beam_token_decode(self,
                          hidden,
                          gpv_output,
                          amedas_output,
                          meta_output,
                          weather_hidden,
                          beam_width=5):
        """Beam-search decoding (at most 128 steps) with beam size `beam_width`.

        The `BeamSearch` instance is stored on `self.beam_search`.

        Returns:
            The `(predictions, log_probs)` pair produced by `BeamSearch.search`.
        """
        max_steps = 128  # maximum number of decoding steps
        self.beam_search = BeamSearch(end_index=IDs.EOS.value,
                                      max_steps=max_steps,
                                      beam_size=beam_width)

        batch_size = hidden.size(1)
        start_predictions = torch.tensor([IDs.BOS.value] * batch_size,
                                         dtype=torch.long,
                                         device=self.device)
        # NOTE(review): `hidden` is indexed with the batch on dim 1 above; a beam search
        # that treats dim 0 of every state tensor as the group dimension would reshape
        # it incorrectly — confirm against the BeamSearch implementation in use.
        start_state = {
            "prev_tokens": torch.zeros(batch_size, 0, dtype=torch.long,
                                       device=self.device),
            "decoder_hidden": hidden,
        }

        def repeat_for_beams(tensor):
            """Repeat every dim-1 entry `beam_width` times, keeping beams adjacent."""
            lead, width, feat = tensor.size(0), tensor.size(1), tensor.size(-1)
            return tensor.unsqueeze(2)\
                .expand(lead, width, beam_width, feat)\
                .reshape(lead, width * beam_width, feat)

        # Encoder-side tensors, replaced by their beam-expanded versions on demand.
        context = {
            "gpv": gpv_output,
            "amedas": amedas_output,
            "meta": meta_output,
            "weather": weather_hidden,
        }

        def step(last_tokens, current_state, t):
            """One decoding step. `last_tokens` has shape (group_size,); `t` is unused."""
            group_size = last_tokens.size(0)

            # Append the latest tokens to the running history: [group_size, t + 1].
            history = torch.cat(
                [current_state["prev_tokens"], last_tokens.unsqueeze(1)], dim=-1)

            # Once the group grows beyond the batch, expand the context to match.
            if group_size != context["gpv"].size(1):
                for key in ("gpv", "amedas", "meta"):
                    context[key] = repeat_for_beams(context[key])
                if context["weather"] is not None:
                    context["weather"] = repeat_for_beams(context["weather"])

            decoder_output, decoder_hidden, _ = self.token_decoder(
                history[:, -1], current_state["decoder_hidden"], context["gpv"],
                context["amedas"], context["meta"], context["weather"])

            current_state["prev_tokens"] = history
            current_state["decoder_hidden"] = decoder_hidden
            return (decoder_output, current_state)

        predictions, log_probs = self.beam_search.search(
            start_predictions=start_predictions,
            start_state=start_state,
            step=step)
        return predictions, log_probs
class Seq2seqPlmsGenerator(Model):
    """
    Sequence-to-sequence generator backed by a pre-trained `MT5ForConditionalGeneration`.

    Training and validation run a teacher-forced pass and return a loss; without targets
    in evaluation mode the model decodes with beam search.

    # Parameters

    vocab : `Vocabulary`, required
        Vocabulary passed to the `Model` base class and to the indexer when decoding.
    pretrained_model_path : `str`, required
        Name or path given to `from_pretrained` (and to the default indexer).
    beam_size : `int`, optional (default = `5`)
        Beam width; a falsy value falls back to `1`.
    max_decoding_steps : `int`, optional (default = `140`)
        Maximum number of beam search steps.
    indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
        Indexer used to turn predicted ids back into tokens. Built from
        `pretrained_model_path` when omitted.
    """

    def __init__(self, vocab: Vocabulary, pretrained_model_path, beam_size=5,
                 max_decoding_steps=140, indexer=None):
        super().__init__(vocab)
        self.plm = MT5ForConditionalGeneration.from_pretrained(pretrained_model_path)
        self._indexer = indexer or PretrainedTransformerIndexer(pretrained_model_path,
                                                                namespace="tokens")

        config = self.plm.config
        self._decoder_start_id = config.decoder_start_token_id
        # `_start_id` is read below when building the metric exclusion sets, so it must
        # exist; it mirrors the decoder start token.
        self._start_id = self._decoder_start_id
        self._end_id = config.eos_token_id
        self._pad_id = config.pad_token_id

        self._beam_search = BeamSearch(
            self._end_id, max_steps=max_decoding_steps, beam_size=beam_size or 1
        )

        # Special tokens are ignored by both metrics.
        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

    def _teacher_forced_outputs(self, input_ids, input_mask, target_ids, target_mask,
                                average: str):
        """Run the model with gold decoder inputs and attach `decoder_logits` and `loss`.

        The decoder sees `target_ids[:, :-1]` and is scored against `target_ids[:, 1:]`
        with label smoothing 0.1; `average` is forwarded to
        `sequence_cross_entropy_with_logits`.
        """
        outputs = self.plm(input_ids=input_ids,
                           attention_mask=input_mask,
                           decoder_input_ids=target_ids[:, :-1].contiguous(),
                           decoder_attention_mask=target_mask[:, :-1].contiguous(),
                           use_cache=False,
                           return_dict=True)
        outputs["decoder_logits"] = outputs.logits
        outputs["loss"] = sequence_cross_entropy_with_logits(
            outputs.logits,
            cast(torch.LongTensor, target_ids[:, 1:].contiguous()),
            cast(torch.BoolTensor, target_mask[:, 1:].contiguous()),
            label_smoothing=0.1,
            average=average,
        )
        return outputs

    @overrides
    def forward(self, source_tokens, target_tokens=None) -> Dict[str, torch.Tensor]:
        """
        Run training, validation or prediction depending on mode and inputs.

        # Parameters

        source_tokens : required
            Read as `source_tokens["tokens"]["token_ids"]` and `["tokens"]["mask"]`.
        target_tokens : optional (default = `None`)
            Same layout as `source_tokens`. When omitted, the source shifted by one
            position stands in for the target.

        # Returns

        In training mode, or in evaluation mode with targets: the model output extended
        with `decoder_logits` and `loss` (ROUGE/BLEU are updated from the arg-max tokens
        in the evaluation case). In evaluation mode without targets: a dict with
        `predictions`, `log_probabilities`, `predicted_tokens` and `predicted_text`.
        """
        input_ids = source_tokens["tokens"]["token_ids"]
        input_mask = source_tokens["tokens"]["mask"]

        if target_tokens is not None:
            target_ids = target_tokens["tokens"]["token_ids"]
            target_mask = target_tokens["tokens"]["mask"]
        else:
            # No targets: use the input shifted by one position.
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:
            return self._teacher_forced_outputs(input_ids, input_mask, target_ids,
                                                target_mask, average="token")

        if target_tokens is not None:
            # Validation: "batch" is the library's default averaging, as used before.
            outputs = self._teacher_forced_outputs(input_ids, input_mask, target_ids,
                                                   target_mask, average="batch")
            greedy_tokens = torch.argmax(outputs.logits, -1)
            self._rouge(greedy_tokens, target_ids)
            self._bleu(greedy_tokens, target_ids)
            return outputs

        # Prediction: beam search starting from the decoder start token.
        outputs = {}
        initial_decoder_ids = torch.tensor(
            [[self._decoder_start_id]],
            dtype=input_ids.dtype,
            device=input_ids.device,
        ).repeat(input_ids.shape[0], 1)
        initial_state = {
            "input_ids": input_ids,
            "input_mask": input_mask,
        }
        beam_result = self._beam_search.search(
            initial_decoder_ids, initial_state, self.take_step
        )
        predictions = beam_result[0]
        logger.info(beam_result)

        # Keep, for every batch element, the beam with the highest log-probability.
        max_pred_indices = (
            beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1])
        )
        predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1)

        # NOTE(review): no targets exist in this branch, so the metrics are updated
        # against the shifted source ids — confirm this is intended.
        self._rouge(predictions, target_ids)
        self._bleu(predictions, target_ids)

        outputs["predictions"] = predictions
        outputs["log_probabilities"] = (
            beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1)
        )
        self.make_output_human_readable(outputs)
        return outputs

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return ROUGE and BLEU values; empty while the model is in training mode."""
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache: DecoderCacheType) -> Dict[str, torch.Tensor]:
        """Flatten the per-layer cache into `decoder_cache_{layer}_{slot}` entries."""
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            # Each layer caches the key and value tensors for its self-attention and
            # cross-attention, hence 4 tensors per layer.
            assert len(layer_cache) == 4
            for tensor_index, tensor in enumerate(layer_cache):
                cache_dict[f"decoder_cache_{layer_index}_{tensor_index}"] = tensor
        return cache_dict

    def _dict_to_decoder_cache(self, cache_dict: Dict[str, torch.Tensor]) -> DecoderCacheType:
        """Rebuild the per-layer cache tuple from the flat dict written by
        `_decoder_cache_to_dict`."""
        config = self.plm.config
        # The cache holds one entry per *decoder* layer; fall back to `num_layers` when
        # the config does not set a separate decoder depth.
        num_layers = getattr(config, "num_decoder_layers", None) or config.num_layers
        decoder_cache = []
        for layer_index in range(num_layers):
            base_key = f"decoder_cache_{layer_index}_"
            decoder_cache.append(tuple(cache_dict[base_key + slot] for slot in "0123"))
        assert decoder_cache
        return tuple(decoder_cache)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. A 1-D tensor is given a
            trailing dimension before being fed to the decoder.
        state : `Dict[str, torch.Tensor]`
            Holds `input_ids`, `input_mask`, optionally `encoder_states`, and the
            flattened decoder cache. Updated in place.
        step : `int`
            The time step in beam search decoding (unused).

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            Log-probabilities over the vocabulary for the next token, and the updated
            state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Every key that is not an encoder-side input belongs to the decoder cache.
        decoder_cache = None
        decoder_cache_dict = {
            k: state[k].contiguous()
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if decoder_cache_dict:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        # After the first step the encoder output is reused instead of recomputed.
        encoder_outputs = (state["encoder_states"],) if "encoder_states" in state else None
        outputs = self.plm(
            input_ids=state["input_ids"] if encoder_outputs is None else None,
            attention_mask=state["input_mask"],
            encoder_outputs=encoder_outputs,
            decoder_input_ids=last_predictions,
            past_key_values=decoder_cache,
            use_cache=True,
            return_dict=True,
        )

        log_probabilities = F.log_softmax(outputs.logits[:, -1, :], dim=-1)

        decoder_cache = outputs.past_key_values
        if decoder_cache is not None:
            state.update(self._decoder_cache_to_dict(decoder_cache))
            state["encoder_states"] = outputs.encoder_last_hidden_state

        return log_probabilities, state

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Add token and text renderings of the predictions.

        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions under the key `predictions`,
            indexed by batch element along the first dimension.

        # Returns

        `Dict[str, Any]`
            The same `output_dict` with `predicted_tokens` (one token list per batch
            element, from the indexer) and `predicted_text` (tokenizer `batch_decode`
            with special tokens skipped).
        """
        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()},
                self.vocab,
            )
        output_dict["predicted_tokens"] = predicted_tokens  # type: ignore
        output_dict["predicted_text"] = self._indexer._tokenizer.batch_decode(
            predictions.tolist(), skip_special_tokens=True
        )
        return output_dict
class MachampSeq2SeqDecoder(Model):
    """
    An autoregressive LSTM decoder that can be used for most seq2seq tasks.

    # Parameters

    task : `str`, required
        Task name. The target vocabulary namespace is `task + '_target_words'` and the
        name is also used in the metric keys.
    vocab : `Vocabulary`, required
        Vocabulary containing the target namespace.
    input_dim : `int`, required
        Dimensionality of the encoder outputs. The decoder's hidden size is the same,
        since the decoder state is initialised from the final encoder state.
    max_decoding_steps : `int`, required
        Maximum length of decoded sequences.
    loss_weight : `float`, optional (default = `1.0`)
        Factor applied to the loss.
    attention : `Attention`, optional (default = `None`)
        If given, an attention-weighted summary of the encoder outputs is concatenated
        to the previous target embedding at every decoding step.
    beam_size : `int`, optional (default = `None`)
        Width of the beam for beam search; `None` means a beam of size 1.
    target_namespace : `str`, optional (default = `'target_tokens'`)
        Accepted for compatibility but not used: the namespace is derived from `task`.
    target_embedding_dim : `int`, optional (default = `input_dim`)
        Embedding dimensionality for the target side.
    scheduled_sampling_ratio : `float`, optional (default = `0.`)
        During training, each timestep draws a random number in [0, 1); if it is below
        this value the previous prediction is fed for the whole batch, otherwise the
        gold token is fed. 0.0 is plain teacher forcing. See
        [Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks.
        Bengio et al., 2015](https://arxiv.org/abs/1506.03099).
    use_bleu : `bool`, optional (default = `True`)
        If True, the BLEU metric is calculated in evaluation mode when targets are given.
    bleu_ngram_weights : `Iterable[float]`, optional (default = `(0.25, 0.25, 0.25, 0.25)`)
        Weights to assign to scores for each ngram size.
    target_decoder_layers : `int`, optional (default = `1`)
        Number of decoder layers; values above 1 use an `LSTM` instead of an `LSTMCell`.
    """

    def __init__(
        self,
        task: str,
        vocab: Vocabulary,
        input_dim: int,
        max_decoding_steps: int,
        loss_weight: float = 1.0,
        attention: Attention = None,
        beam_size: int = None,
        target_namespace: str = "target_tokens",
        target_embedding_dim: int = None,
        scheduled_sampling_ratio: float = 0.0,
        use_bleu: bool = True,
        bleu_ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25),
        target_decoder_layers: int = 1,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self.task = task
        self.vocab = vocab
        self.loss_weight = loss_weight

        self._target_namespace = task + '_target_words'
        self._target_decoder_layers = target_decoder_layers
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of
        # decoding, and the end symbol to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)
            self._bleu = BLEU(bleu_ngram_weights,
                              exclude_indices={
                                  pad_index, self._end_index, self._start_index
                              })
        else:
            self._bleu = None
        self.metrics = {"bleu": self._bleu}

        # At prediction time, we use a beam search to find the most likely sequence of
        # target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        num_classes = self.vocab.get_vocab_size(namespace=self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention

        # Decoder output dim needs to be the same as the encoder output dim since we
        # initialize the hidden state of the decoder with the final hidden state of the
        # encoder. These are set before they are needed as the embedding-size default.
        self._encoder_output_dim = input_dim
        self._decoder_output_dim = self._encoder_output_dim

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or self._encoder_output_dim
        self._target_embedder = Embedding(num_embeddings=num_classes,
                                          embedding_dim=target_embedding_dim)

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be
            # concatenated to the previous target embedding to form the input to the
            # decoder at each time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM (cell) as the recurrent unit that produces a hidden state
        # for the decoder at each time step.
        if self._target_decoder_layers > 1:
            self._decoder_cell = LSTM(
                self._decoder_input_dim,
                self._decoder_output_dim,
                self._target_decoder_layers,
            )
        else:
            self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                          self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    @overrides
    def forward(
            self,  # type: ignore
            embedded_text: torch.LongTensor,
            source_mask: torch.LongTensor,
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        """
        Decode from already-encoded source text.

        # Parameters

        embedded_text : `torch.Tensor`
            Encoder outputs, stored in the state as `encoder_outputs`.
        source_mask : `torch.Tensor`
            Mask over the source positions.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            Gold targets, read as `target_tokens["tokens"]["tokens"]`.

        # Returns

        `Dict[str, torch.Tensor]`
            With targets: the `_forward_loop` output (including `loss`). In evaluation
            mode the beam search `predictions` and `class_log_probabilities` are added
            (overwriting the greedy `predictions`), and BLEU is updated when targets
            are available.
        """
        state = {"encoder_outputs": embedded_text, "source_mask": source_mask}

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during
            # training and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["tokens"]["tokens"])

        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert predicted indices into tokens.

        Unlike `Model.make_output_human_readable`, `output_dict` is used directly as
        the predictions (a tensor or numpy array), not as a dictionary, and a list is
        returned: one list of token lists per batch element. Every sequence is trimmed
        at the first end symbol.
        """
        predicted_indices = output_dict
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()

        all_predicted_tokens = []
        for top_k_predictions in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the
            # batch; a single 1-D sequence is wrapped so both cases share one loop.
            if len(top_k_predictions.shape) == 1:
                top_k_predictions = [top_k_predictions]

            batch_predicted_tokens = []
            for indices in top_k_predictions:
                indices = list(indices)
                # Collect indices till the first end_symbol
                if self._end_index in indices:
                    indices = indices[:indices.index(self._end_index)]
                predicted_tokens = [
                    self.vocab.get_token_from_index(
                        x, namespace=self._target_namespace) for x in indices
                ]
                batch_predicted_tokens.append(predicted_tokens)

            all_predicted_tokens.append(batch_predicted_tokens)
        return all_predicted_tokens

    def take_step(self, last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  step: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        # Parameters

        last_predictions : `torch.Tensor`
            A tensor of shape `(group_size,)`, which gives the indices of the
            predictions during the last time step.
        state : `Dict[str, torch.Tensor]`
            The current state: encoder outputs, source mask, and the decoder hidden
            state and context.
        step : `int`
            The time step in beam search decoding (unused).

        # Returns

        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            `(log_probabilities, updated_state)`, where `log_probabilities` has shape
            `(group_size, num_classes)`.

        Notes
        -----
        We treat the inputs as a batch, even though `group_size` is not necessarily
        equal to `batch_size`, since the group may contain multiple states for each
        source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Add `decoder_hidden` (final encoder state) and `decoder_context` (zeros) to
        `state`, repeated per layer when more than one decoder layer is used."""
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            bidirectional=False)

        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)

        if self._target_decoder_layers > 1:
            # NOTE(review): these tensors become layer-first; a beam search that treats
            # dim 0 of every state tensor as the group dimension may mis-handle them —
            # confirm for multi-layer decoding.
            # shape: (num_layers, batch_size, decoder_output_dim)
            state["decoder_hidden"] = (
                state["decoder_hidden"].unsqueeze(0).repeat(
                    self._target_decoder_layers, 1, 1))
            # shape: (num_layers, batch_size, decoder_output_dim)
            state["decoder_context"] = (
                state["decoder_context"].unsqueeze(0).repeat(
                    self._target_decoder_layers, 1, 1))

        return state

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]["tokens"]
            _, target_sequence_length = targets.size()
            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index,
                                                dtype=torch.long)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: feed the model's own previous predictions, at a
                # rate of _scheduled_sampling_ratio during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # No gold tokens available: decode greedily from previous predictions.
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # Teacher forcing: feed the gold token.
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        # NOTE(review): "class_probabilities" carries the predicted indices, not
        # probabilities — confirm downstream consumers expect this.
        output_dict = {
            "predictions": predictions,
            "class_probabilities": predictions
        }

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask) * self.loss_weight
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index, dtype=torch.long)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode current state and last prediction to produce projections into the target
        space, which can then be used to get probabilities of each target token for the
        next step. `state["decoder_hidden"]` and `state["decoder_context"]` are updated.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (num_layers, group_size, decoder_output_dim) when multi-layer
        decoder_hidden = state["decoder_hidden"]

        # shape: (num_layers, group_size, decoder_output_dim) when multi-layer
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            if self._target_decoder_layers > 1:
                # Attend with the first layer's hidden state.
                attended_input = self._prepare_attended_input(
                    decoder_hidden[0], encoder_outputs, source_mask)
            else:
                attended_input = self._prepare_attended_input(
                    decoder_hidden, encoder_outputs, source_mask)
            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        if self._target_decoder_layers > 1:
            # The LSTM expects a leading sequence dimension of length 1.
            # shape: (1, batch_size, target_embedding_dim)
            decoder_input = decoder_input.unsqueeze(0).contiguous()

            decoder_context = decoder_context.contiguous()
            decoder_hidden = decoder_hidden.contiguous()
            # shape (decoder_hidden): (num_layers, batch_size, decoder_output_dim)
            # shape (decoder_context): (num_layers, batch_size, decoder_output_dim)
            # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
            with torch.cuda.amp.autocast(False):
                _, (decoder_hidden, decoder_context) = self._decoder_cell(
                    decoder_input.float(),
                    (decoder_hidden.float(), decoder_context.float()))
        else:
            # shape (decoder_hidden): (batch_size, decoder_output_dim)
            # shape (decoder_context): (batch_size, decoder_output_dim)
            # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
            with torch.cuda.amp.autocast(False):
                decoder_hidden, decoder_context = self._decoder_cell(
                    decoder_input.float(),
                    (decoder_hidden.float(), decoder_context.float()))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        if self._target_decoder_layers > 1:
            # Project the last layer's hidden state.
            output_projections = self._output_projection_layer(
                decoder_hidden[-1])
        else:
            output_projections = self._output_projection_layer(decoder_hidden)
        return output_projections, state

    def _prepare_attended_input(
        self,
        decoder_hidden_state: torch.LongTensor = None,
        encoder_outputs: torch.LongTensor = None,
        encoder_outputs_mask: torch.BoolTensor = None,
    ) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(decoder_hidden_state, encoder_outputs,
                                        encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(
        logits: torch.LongTensor,
        targets: torch.LongTensor,
        target_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size
        (batch_size, num_decoding_steps, num_classes), target indices of size
        (batch_size, num_decoding_steps+1) and corresponding masks of size
        (batch_size, num_decoding_steps+1) and computes cross entropy loss while taking
        the mask into account.

        The length of `targets` is expected to be greater than that of `logits` because
        the decoder does not need to compute the output corresponding to the last
        timestep of `targets`. This method aligns the inputs appropriately to compute
        the loss.

        During training, we want the logit corresponding to timestep i to be similar to
        the target token from timestep i + 1. That is, the targets should be shifted by
        one timestep for appropriate comparison. Consider a single example where the
        target has 3 words, and padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return each metric's `get_metric(reset)` result keyed by
        `.run/<task>/<metric_name>`; empty when BLEU is disabled."""
        main_metrics: Dict[str, float] = {}
        if self._bleu:
            main_metrics = {
                f".run/{self.task}/{metric_name}": metric.get_metric(reset)
                for metric_name, metric in self.metrics.items()
            }
        return {**main_metrics}
class Bart(Model):
    """
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language
    Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here
    uses a language modeling head and thus can be used for text generation.
    """

    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
            `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
            When omitted, a `PretrainedTransformerIndexer` for `model_name` is created with
            namespace `tokens`.
        max_decoding_steps : `int`, optional (default = `140`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. The default is from the BART paper.
            A falsy value (`None` or `0`) falls back to a beam size of 1.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to be used in BART. By default, the original BART encoder is used.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        # Fall back to the BOS id only when the config leaves the decoder start id unset.
        # A plain `or` would also discard an explicitly configured id of 0, which is a
        # valid token id.
        decoder_start_id = self.bart.config.decoder_start_token_id
        self._decoder_start_id = (self._start_id if decoder_start_id is None
                                  else decoder_start_id)
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        # Special tokens are excluded so they do not count towards the metrics.
        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )

    @overrides
    def forward(
            self,
            source_tokens: TextFieldTensors,
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        """
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key. If no target
            tokens are given, the source tokens are shifted to the right by 1.

        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape `(batch_size,
            max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions`
            of shape `(batch_size, max_decoding_steps)` and `log_probabilities` of shape `(batch_size,)`,
            plus the `predicted_tokens` added by `make_output_human_readable`.
        """
        input_ids = source_tokens["tokens"]["token_ids"]
        input_mask = source_tokens["tokens"]["mask"]

        outputs = {}

        # If no targets are provided, then shift input to right by 1. Bart already does this internally
        # but it does not use them for loss calculation.
        if target_tokens is not None:
            target_ids = target_tokens["tokens"]["token_ids"]
            target_mask = target_tokens["tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:
            # The decoder is fed everything but the last target token and is scored
            # against everything but the first one.
            decoder_logits = self.bart(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
            )[0]

            outputs["decoder_logits"] = decoder_logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks
            outputs["loss"] = sequence_cross_entropy_with_logits(
                decoder_logits,
                target_ids[:, 1:].contiguous(),
                target_mask[:, 1:].contiguous(),
                label_smoothing=0.1,
                average="token",
            )
        else:
            # Use decoder start id and start of sentence to start decoder
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id, self._start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "encoder_states": None,
            }
            beam_result = self._beam_search.search(initial_decoder_ids,
                                                   initial_state,
                                                   self.take_step)

            # Keep only the highest scoring beam of every batch element.
            predictions = beam_result[0]
            max_pred_indices = (beam_result[1].argmax(dim=-1).view(
                -1, 1, 1).expand(-1, -1, predictions.shape[-1]))
            predictions = predictions.gather(
                dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (beam_result[1].gather(
                dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1))

            self.make_output_human_readable(outputs)

        return outputs

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache):
        """
        Flatten a nested decoder cache (sequence of per-layer dicts of per-attention dicts)
        into a single dict keyed by `(layer_index, attention_name, tensor_name)`.
        """
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            for attention_name, attention_cache in layer_cache.items():
                for tensor_name, cache_value in attention_cache.items():
                    key = (layer_index, attention_name, tensor_name)
                    cache_dict[key] = cache_value
        return cache_dict

    @staticmethod
    def _dict_to_decoder_cache(cache_dict):
        """
        Inverse of `_decoder_cache_to_dict`: rebuild the list of per-layer nested dicts from
        a flat dict keyed by `(layer_index, attention_name, tensor_name)`.
        """
        decoder_cache = []
        for (layer_idx, attention_name, tensor_name), cache_value in cache_dict.items():
            # Extend decoder_cache to fit layer_idx + 1 layers
            decoder_cache.extend(
                {} for _ in range(layer_idx + 1 - len(decoder_cache)))
            attention_cache = decoder_cache[layer_idx].setdefault(
                attention_name, {})
            assert tensor_name not in attention_cache
            attention_cache[tensor_name] = cache_value
        return decoder_cache

    def take_step(self, last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  step: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions. Besides `input_ids`, `input_mask`
            and `encoder_states`, every other entry is treated as a flattened decoder cache entry.
            The dictionary is updated in place.
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)`
            and an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Only the last predictions are needed for the decoder, but we need to pad the decoder ids
        # to not mess up the positional embeddings in the decoder.
        padding_size = 0
        if step > 0:
            padding_size = step + 1
            padding = torch.full(
                (last_predictions.shape[0], padding_size),
                self._pad_id,
                dtype=last_predictions.dtype,
                device=last_predictions.device,
            )
            last_predictions = torch.cat([padding, last_predictions], dim=-1)

        # Everything in the state that is not one of the fixed entries belongs to the decoder cache.
        decoder_cache = None
        decoder_cache_dict = {
            k: (state[k].contiguous() if state[k] is not None else None)
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if decoder_cache_dict:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        log_probabilities = None
        for i in range(padding_size, last_predictions.shape[1]):
            # Reuse the encoder output once a previous call has stored it in the state.
            encoder_outputs = ((state["encoder_states"], )
                               if state["encoder_states"] is not None else None)
            outputs = self.bart(
                input_ids=state["input_ids"],
                attention_mask=state["input_mask"],
                encoder_outputs=encoder_outputs,
                decoder_input_ids=last_predictions[:, :i + 1],
                past_key_values=decoder_cache,
                use_cache=True,
            )

            decoder_log_probabilities = F.log_softmax(outputs[0][:, 0], dim=-1)

            if log_probabilities is None:
                log_probabilities = decoder_log_probabilities
            else:
                # Add the log probability that the previous iteration assigned to the
                # token actually fed in at position i.
                idx = last_predictions[:, i].view(-1, 1)
                log_probabilities = decoder_log_probabilities + log_probabilities.gather(
                    dim=-1, index=idx)

            decoder_cache = outputs[1]
            state["encoder_states"] = outputs[2]

        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        return log_probabilities, state

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor should have
            shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of lists of
            tokens.
        """
        output_dict["predicted_tokens"] = [
            self._indexer.indices_to_tokens({"token_ids": token_ids.tolist()},
                                            self.vocab)
            for token_ids in output_dict["predictions"]
        ]
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the ROUGE and BLEU metrics. They are only reported outside of training;
        during training an empty dictionary is returned.
        """
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics
class EncoderDecoder(Model):
    """
    Encoder/decoder sequence model with a coverage term in its loss.

    Training decodes either with teacher forcing or free-running greedy decoding (chosen at
    random per call according to `teacher_force_ratio`); prediction uses beam search.
    """

    def __init__(self,
                 source_embedder: TextFieldEmbedder,
                 target_embedder: TextFieldEmbedder,
                 max_steps: int,
                 encoder: Encoder,
                 decoder: Decoder,
                 hidden_size: int,
                 vocab: Vocabulary,
                 teacher_force_ratio: float,
                 regularizer: RegularizerApplicator = None) -> None:
        """
        :param source_embedder: Embedder applied to the source tokens
        :param target_embedder: Embedder applied to the target tokens
        :param max_steps: Maximum number of decoding steps (beam search and free-running decoding)
        :param encoder: The encoder module
        :param decoder: The decoder module; it receives the vocabulary through `add_vocab`
        :param hidden_size: The hidden size; encoder states are asserted to be twice this size
        :param vocab: The vocabulary
        :param teacher_force_ratio: Probability of using teacher forcing in a decoding pass
        :param regularizer: Optional regularizer passed on to the base model
        """
        super().__init__(vocab, regularizer)
        # TODO: Workon BeamSearch, try to switch to OpenNMT BeamSearch but implement our own beamsearch first
        self.max_steps = max_steps
        self.hidden_size = hidden_size
        self.source_embedder = source_embedder
        self.target_embedder = target_embedder
        self.encoder = encoder
        self.decoder = decoder
        self.teacher_force_ratio = teacher_force_ratio
        self.decoder.add_vocab(self.vocab)

        self.padding_idx = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN)
        self.start_idx = self.vocab.get_token_index(START_SYMBOL)
        self.end_idx = self.vocab.get_token_index(END_SYMBOL)
        self.unk_idx = self.vocab.get_token_index(DEFAULT_OOV_TOKEN)

        self.beam = BeamSearch(self.end_idx,
                               max_steps=self.max_steps,
                               beam_size=5)
        # Padding positions do not contribute to the cross entropy.
        self.criterion = CrossEntropyLoss(ignore_index=self.padding_idx)

    # noinspection PyMethodMayBeStatic
    def init_enc_state(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Build the initial state dictionary from the source tokens

        :param source_tokens: Indexes of source tokens
        :return: A state holding the source mask, the source lengths and the raw token indexes
        """
        source_mask = util.get_text_field_mask(source_tokens)
        source_lengths = get_lengths_from_binary_sequence_mask(source_mask)
        state = {
            'source_mask': source_mask,  # (B, L)
            'source_lengths': source_lengths,  # presumably one length per sequence - TODO confirm
            'source_tokens': source_tokens['tokens'],
        }
        return state

    def init_dec_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Add the initial decoder entries to the state (in place)

        :param state: The state produced by `_encode`
        :return: The same state with `context`, `dec_state` and `coverage` set
        """
        states = state['encoder_states']
        batch_size = states.size(0)
        length = states.size(1)
        state['context'] = states.new_zeros((batch_size, 1, self.hidden_size))
        state['dec_state'] = state['hidden']
        state['coverage'] = states.new_zeros(
            (batch_size, length, 1))  # (B, L, 1)
        return state

    def forward(self,
                source_tokens: Dict[str, torch.Tensor],
                source_text: Dict[str, Any],
                source_ids: Dict[str, torch.Tensor],
                target_tokens: Dict[str, torch.Tensor] = None,
                saliency_values: torch.Tensor = None) \
            -> Dict[str, torch.Tensor]:
        """
        The forward function of the encoder and decoder model

        :param source_tokens: Indexes of source tokens
        :param source_text: The raw text of source sequence (not used in this method)
        :param source_ids: The source ids that is unique to the document
        :param target_tokens: Indexes of target tokens
        :param saliency_values: The saliency values for source tokens (not used in this method)
        :return: A dictionary with `loss` when targets are given, or with `predictions` when
                 the model is not training and no targets are given
        """
        state = self._encode(source_tokens)
        output_dict = {}
        if target_tokens:
            state = self._decode(source_ids, target_tokens, state)
            output_dict['loss'] = self._compute_loss(target_tokens, state)
        if not self.training and not target_tokens:
            output_dict['predictions'] = self._forward_beam_search(
                state, source_ids)
        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor],
            source_ids: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        state = self.init_dec_state(state)
        state['source_ids'] = source_ids['ids']
        state['max_oov'] = source_ids['max_oov']

        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self.start_idx)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self.beam.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step. Each of these tensors has shape
            ``(group_size, *)``, where ``*`` can be any other number of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is the state dictionary returned by the decoder.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        emb = self.target_embedder({'tokens': last_predictions})
        state = self.decoder(emb, state)
        # shape: (group_size, num_classes)
        # log_softmax is used instead of softmax followed by log: the latter yields -inf
        # whenever a probability underflows to zero.
        log_probabilities = torch.nn.functional.log_softmax(
            state['class_logits'].squeeze(1), dim=-1)
        return log_probabilities, state

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode the source tokens

        :param source_tokens: The indexes of source tokens
        :return: The state with all the encoder states (`encoder_states`) and the final
                 encoder state (`hidden`) added
        """
        state = self.init_enc_state(source_tokens)
        # (Batch, Seq, Emb Dim)
        embedded_src = self.source_embedder(source_tokens)
        # final_state = (last state, last context)
        states, final_state = self.encoder(embedded_src,
                                           state['source_lengths'])
        state['encoder_states'] = states  # (B, L, Num Direction * D_h)
        state['hidden'] = final_state
        # The rest of the model relies on the encoder states being twice the hidden size.
        assert state['encoder_states'].size(2) == (2 * self.hidden_size)
        return state

    def _decode(self, source_ids: Dict[str, torch.Tensor],
                target_tokens: Dict[str, torch.Tensor],
                state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Decode from the encoder state

        With probability `teacher_force_ratio` the gold target embeddings are fed to the
        decoder step by step; otherwise the decoder runs for `max_steps` steps on its own
        greedy predictions.

        :param source_ids: Dictionary providing `ids` and `max_oov`, both copied into the state
        :param target_tokens: The indexes of target tokens
        :param state: The state produced by `_encode`
        :return: The state with `all_class_logits`, `all_coverages` and `all_attentions`
                 concatenated over the decoding steps
        """
        state = self.init_dec_state(state)
        state['source_ids'] = source_ids['ids']
        state['max_oov'] = source_ids['max_oov']

        all_class_logits = []
        all_coverages = []
        all_attentions = []

        # Teacher Forcing
        if torch.rand(1).item() <= self.teacher_force_ratio:
            embedded_tgt = self.target_embedder(target_tokens)
            for emb in embedded_tgt.split(1, dim=1):
                state = self.decoder(emb, state)
                all_class_logits.append(state['class_logits'])
                all_coverages.append(state['coverage'])
                all_attentions.append(state['attention'])
        else:
            tokens = state["encoder_states"].new_full(
                (state["encoder_states"].size(0), ),
                fill_value=self.start_idx,
                dtype=torch.long)
            emb = self.target_embedder({'tokens': tokens})
            for _ in range(self.max_steps):
                state = self.decoder(emb, state)
                all_class_logits.append(state['class_logits'])
                all_coverages.append(state['coverage'])
                all_attentions.append(state['attention'])
                # Greedy choice of the next input token.
                _, tokens = torch.topk(
                    Softmax(dim=-1)(all_class_logits[-1]), 1)
                # Indexes beyond the vocabulary cannot be embedded, so they are mapped to
                # the unknown token.
                tokens[tokens >= self.vocab.get_vocab_size()] = self.unk_idx
                emb = self.target_embedder({'tokens': tokens.squeeze(1)})

        state['all_class_logits'] = torch.cat(all_class_logits, dim=1)
        state['all_coverages'] = torch.cat(all_coverages, dim=1)
        state['all_attentions'] = torch.cat(all_attentions, dim=1)
        # The per-step entries are no longer needed once the concatenated ones exist.
        state.pop('class_logits', None)
        state.pop('coverage', None)
        state.pop('attention', None)
        return state

    def _compute_loss(self, target_tokens: Dict[str, torch.Tensor],
                      state: Dict[str, torch.Tensor]):
        """
        Compute the cross entropy loss plus the coverage loss

        :param target_tokens: The indexes of target tokens
        :param state: The state produced by `_decode`
        :return: The sum of the cross entropy loss and the batch-averaged coverage loss
        """
        # Dimensions 1 and 2 are swapped so that the decoding steps come last.
        all_class_logits = state['all_class_logits'].transpose(1, 2).contiguous()
        attentions = state['all_attentions']
        coverages = state['all_coverages']

        # The first target token is never predicted, so it is dropped.
        tokens = target_tokens['tokens'][:, 1:]
        batch_size = tokens.size(0)

        # The logits of the last decoding step are not scored.
        num_scored_steps = all_class_logits.size(2) - 1
        pad_tokens = all_class_logits.new_full(
            (all_class_logits.size(0), num_scored_steps),
            fill_value=self.padding_idx,
            dtype=torch.long)
        # Free-running decoding is capped at `max_steps`, so the target may be longer than
        # the decoded sequence; copy only what fits instead of failing on a shape mismatch.
        num_copied = min(tokens.size(1), num_scored_steps)
        pad_tokens[:, :num_copied] = tokens[:, :num_copied]

        loss = self.criterion(all_class_logits[:, :, :-1], pad_tokens)
        coverage_loss = torch.min(attentions, coverages).sum() / batch_size
        total_loss = loss + coverage_loss
        return total_loss
class BeamSearchTest(AllenNlpTestCase): def setup_method(self): super().setup_method() self.end_index = transition_probabilities.size()[0] - 1 self.beam_search = BeamSearch(self.end_index, max_steps=10, beam_size=3) # This is what the top k should look like for each item in the batch. self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5], [3, 4, 5, 5, 5]]) # This is what the log probs should look like for each item in the batch. self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) def _check_results( self, batch_size: int = 5, expected_top_k: np.array = None, expected_log_probs: np.array = None, beam_search: BeamSearch = None, state: Dict[str, torch.Tensor] = None, take_step=take_step_with_timestep, ) -> None: expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k expected_log_probs = (expected_log_probs if expected_log_probs is not None else self.expected_log_probs) state = state or {} beam_search = beam_search or self.beam_search beam_size = beam_search.beam_size initial_predictions = torch.tensor([0] * batch_size) top_k, log_probs = beam_search.search(initial_predictions, state, take_step) # type: ignore # top_k should be shape `(batch_size, beam_size, max_predicted_length)`. assert list(top_k.size())[:-1] == [batch_size, beam_size] np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k) # log_probs should be shape `(batch_size, beam_size, max_predicted_length)`. 
assert list(log_probs.size()) == [batch_size, beam_size] np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs, rtol=1e-6) @pytest.mark.parametrize("step_function", [take_step_with_timestep, take_step_no_timestep]) def test_search(self, step_function): self._check_results(take_step=step_function) def test_finished_state(self): state = {} state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]) # shape: (batch_size, 3) expected_finished_state = {} expected_finished_state["foo"] = np.array([ [1, 0, 1], [1, 0, 1], [1, 0, 1], [2, 0, 1], [2, 0, 1], [2, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], ]) # shape: (batch_size x beam_size, 3) self._check_results(state=state) # check finished state. for key, array in expected_finished_state.items(): np.testing.assert_allclose(state[key].numpy(), array) def test_diff_shape_state(self): state = {} state["decoder_hidden"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]) state["decoder_hidden"] = state["decoder_hidden"].unsqueeze(0).repeat( 2, 1, 1) # shape: (2, batch_size, 3) seq = [ [1, 0, 1], [1, 0, 1], [1, 0, 1], [2, 0, 1], [2, 0, 1], [2, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], ] seq = [seq] * 2 expected_finished_state = {} expected_finished_state["decoder_hidden"] = np.array(seq) # shape: (2, batch_size x beam_size, 3) self._check_results(state=state) # check finished state. 
for key, array in expected_finished_state.items(): np.testing.assert_allclose(state[key].numpy(), array) def test_batch_size_of_one(self): self._check_results(batch_size=1) def test_greedy_search(self): beam_search = BeamSearch(self.end_index, beam_size=1) expected_top_k = np.array([[1, 2, 3, 4, 5]]) expected_log_probs = np.log(np.array([0.4])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, beam_search=beam_search, ) def test_single_step(self): self.beam_search.max_steps = 1 expected_top_k = np.array([[1], [2], [3]]) expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, ) def test_early_stopping(self): """ Checks case where beam search will reach `max_steps` before finding end tokens. """ beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3) expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, beam_search=beam_search, ) def test_take_short_sequence_step(self): """ Tests to ensure the top-k from the short_sequence_transition_probabilities transition matrix is expected """ self.beam_search.beam_size = 5 expected_top_k = np.array([[5, 5, 5, 5, 5], [1, 5, 5, 5, 5], [1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]]) expected_log_probs = np.log( np.array([0.9, 0.09, 0.009, 0.0009, 0.0001])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=take_short_sequence_step, ) def test_min_steps(self): """ Tests to ensure all output sequences are greater than a specified minimum length. It uses the `take_short_sequence_step` step function, which favors shorter sequences. See `test_take_short_sequence_step`. 
""" self.beam_search.beam_size = 1 # An empty sequence is allowed under this step function self.beam_search.min_steps = 0 expected_top_k = np.array([[5]]) expected_log_probs = np.log(np.array([0.9])) with pytest.warns(RuntimeWarning, match="Empty sequences predicted"): self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=take_short_sequence_step, ) self.beam_search.min_steps = 1 expected_top_k = np.array([[1, 5]]) expected_log_probs = np.log(np.array([0.09])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=take_short_sequence_step, ) self.beam_search.min_steps = 2 expected_top_k = np.array([[1, 2, 5]]) expected_log_probs = np.log(np.array([0.009])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=take_short_sequence_step, ) self.beam_search.beam_size = 3 self.beam_search.min_steps = 2 expected_top_k = np.array([[1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]]) expected_log_probs = np.log(np.array([0.009, 0.0009, 0.0001])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=take_short_sequence_step, ) def test_different_per_node_beam_size(self): # per_node_beam_size = 1 beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1) self._check_results(beam_search=beam_search) # per_node_beam_size = 2 beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=2) self._check_results(beam_search=beam_search) def test_catch_bad_config(self): """ If `per_node_beam_size` (which defaults to `beam_size`) is larger than the size of the target vocabulary, `BeamSearch.search` should raise a ConfigurationError. 
""" beam_search = BeamSearch(self.end_index, beam_size=20) with pytest.raises(ConfigurationError): self._check_results(beam_search=beam_search) def test_warn_for_bad_log_probs(self): # The only valid next step from the initial predictions is the end index. # But with a beam size of 3, the call to `topk` to find the 3 most likely # next beams will result in 2 new beams that are invalid, in that have probability of 0. # The beam search should warn us of this. initial_predictions = torch.LongTensor( [self.end_index - 1, self.end_index - 1]) with pytest.warns(RuntimeWarning, match="Negligible log probabilities"): self.beam_search.search(initial_predictions, {}, take_step_no_timestep) def test_empty_sequences(self): initial_predictions = torch.LongTensor( [self.end_index - 1, self.end_index - 1]) beam_search = BeamSearch(self.end_index, beam_size=1) with pytest.warns(RuntimeWarning, match="Empty sequences predicted"): predictions, log_probs = beam_search.search( initial_predictions, {}, take_step_with_timestep) # predictions hould have shape `(batch_size, beam_size, max_predicted_length)`. assert list(predictions.size()) == [2, 1, 1] # log probs hould have shape `(batch_size, beam_size)`. assert list(log_probs.size()) == [2, 1] assert (predictions == self.end_index).all() assert (log_probs == 0).all() def test_default_from_params_params(self): beam_search = BeamSearch.from_params( Params({ "beam_size": 2, "end_index": 7 })) assert beam_search.beam_size == 2 assert beam_search._end_index == 7 def test_top_p_search(self): initial_predictions = torch.tensor([0] * 5) beam_size = 3 take_step = take_step_with_timestep p_sampler = TopPSampler(p=0.8) top_p, log_probs = BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler).search( initial_predictions, {}, take_step) beam_size = beam_size or 1 batch_size = 5 # top_p should be shape `(batch_size, beam_size, max_predicted_length)`. 
assert list(top_p.size())[:-1] == [batch_size, beam_size] assert ((0 <= top_p) & (top_p <= 5)).all() # log_probs should be shape `(batch_size, beam_size, max_predicted_length)`. assert list(log_probs.size()) == [batch_size, beam_size] @pytest.mark.parametrize("p_val", [-1.0, 1.2, 1.1, float("inf")]) def test_p_val(self, p_val): with pytest.raises(ValueError): initial_predictions = torch.tensor([0] * 5) take_step = take_step_with_timestep beam_size = 3 p_sampler = TopPSampler(p=p_val, with_replacement=True) top_k, log_probs = BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler).search( initial_predictions, {}, take_step) def test_top_k_search(self): initial_predictions = torch.tensor([0] * 5) beam_size = 3 take_step = take_step_with_timestep k_sampler = TopKSampler(k=5, with_replacement=True) top_k, log_probs = BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler).search( initial_predictions, {}, take_step) beam_size = beam_size or 1 batch_size = 5 # top_p should be shape `(batch_size, beam_size, max_predicted_length)`. assert list(top_k.size())[:-1] == [batch_size, beam_size] assert ((0 <= top_k) & (top_k <= 5)).all() # log_probs should be shape `(batch_size, beam_size, max_predicted_length)`. 
assert list(log_probs.size()) == [batch_size, beam_size] @pytest.mark.parametrize("k_val", [-1, 0]) def test_k_val(self, k_val): with pytest.raises(ValueError): initial_predictions = torch.tensor([0] * 5) take_step = take_step_with_timestep beam_size = 3 k_sampler = TopKSampler(k=k_val, with_replacement=True) top_k, log_probs = BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler).search( initial_predictions, {}, take_step) def test_stochastic_beam_search(self): initial_predictions = torch.tensor([0] * 5) batch_size = 5 beam_size = 3 take_step = take_step_with_timestep gumbel_sampler = GumbelSampler() top_k, log_probs = BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=gumbel_sampler).search( initial_predictions, {}, take_step) # top_p should be shape `(batch_size, beam_size, max_predicted_length)`. assert list(top_k.size())[:-1] == [batch_size, beam_size] assert ((0 <= top_k) & (top_k <= 5)).all() # log_probs should be shape `(batch_size, beam_size, max_predicted_length)`. assert list(log_probs.size()) == [batch_size, beam_size] # Check to make sure that once the end index is predicted, all subsequent tokens # must be the end index. 
This has been tested on toy examples in which for batch in top_k: for beam in batch: reached_end = False for token in beam: if token == self.end_index: reached_end = True if reached_end: assert token == self.end_index def test_params_sampling(self): beam_search = BeamSearch.from_params( Params({ "sampler": { "type": "top-k", "k": 4, }, "beam_size": 2, "end_index": 7, })) assert beam_search.beam_size == 2 assert beam_search._end_index == 7 assert beam_search.sampler is not None def test_params_p_sampling(self): beam_search = BeamSearch.from_params( Params({ "sampler": { "type": "top-p", "p": 0.8, }, "beam_size": 2, "end_index": 7, })) assert beam_search.beam_size == 2 assert beam_search._end_index == 7 assert beam_search.sampler is not None def test_multinomial_sampler(self): sampler = MultinomialSampler(temperature=0.9) probabilities, classes, state = sampler.sample_nodes( log_probabilities, 3, {"foo": "bar"}) assert probabilities.size() == classes.size() assert classes.size() == (2, 3) assert all([x < 4 for x in classes[0]]) assert all([x > 1 for x in classes[1]]) def test_top_k_sampler(self): sampler = TopKSampler(k=3, temperature=0.9) probabilities, classes, state = sampler.sample_nodes( log_probabilities, 3, {"foo": "bar"}) assert probabilities.size() == classes.size() assert classes.size() == (2, 3) assert all([x > 0 and x < 4 for x in classes[0]]) assert all([x > 1 and x < 5 for x in classes[1]]) def test_top_p_sampler(self): sampler = TopPSampler(p=0.8, temperature=0.9) probabilities, classes, state = sampler.sample_nodes( log_probabilities, 3, {"foo": "bar"}) assert probabilities.size() == classes.size() assert classes.size() == (2, 3) assert all([x > 0 and x < 4 for x in classes[0]]) assert all([x > 1 and x < 5 for x in classes[1]]) # Make sure the filtered classes include the first class that exceeds p sampler = TopPSampler(p=0.7, temperature=1.0) probabilities, classes, state = sampler.sample_nodes( log_probabilities, 2, {"foo": "bar"}) assert all([x == 
2 or x == 3 or x == 1 for x in classes[0]]) assert all([x == 2 or x == 3 for x in classes[1]]) def test_gumbel_sampler(self): sampler = GumbelSampler() num_classes = len(log_probabilities[0]) sampler_state = sampler.init_state(log_probabilities, batch_size=2, num_classes=num_classes) log_probs, indices, state = sampler.sample_beams( log_probabilities, 3, sampler_state) assert log_probs.size() == indices.size() assert indices.size() == (2, 3) # Make sure the probabilities are sorted. _, sorted_indices = log_probs.sort(dim=-1, descending=True) assert (sorted_indices == torch.arange(3).unsqueeze(0)).all() assert all([x >= 0 and x < 4 for x in indices[0]]) assert all([x > 1 and x <= 5 for x in indices[1]]) def test_length_normalized_sequence_log_prob_scorer(self): """ Tests to ensure the sequences are normalized by the correct values. The end token is included in the length. The start token is not. """ self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( ) expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) length_normalization = np.array([5, 4, 3]) expected_scores = expected_log_probs / length_normalization self._check_results(expected_log_probs=expected_scores) # Introduce a length penalty length_penalty = 2.0 self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( length_penalty=length_penalty) expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) length_normalization = np.array( [5**length_penalty, 4**length_penalty, 3**length_penalty]) expected_scores = expected_log_probs / length_normalization self._check_results(expected_log_probs=expected_scores) # Pick a length penalty so extreme that the order of the sequences is reversed length_penalty = -2.0 self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( length_penalty=length_penalty) expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5], [1, 2, 3, 4, 5]]) expected_log_probs = np.log(np.array([0.2, 0.3, 0.4])) 
length_normalization = np.array( [3**length_penalty, 4**length_penalty, 5**length_penalty]) expected_scores = expected_log_probs / length_normalization self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores) # Here, we set the max_steps = 4. This prevents the first sequence from finishing, # so its length does not include the end token, whereas the other sequences do. length_penalty = 2.0 self.beam_search.max_steps = 4 self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( length_penalty=length_penalty) expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]]) expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) length_normalization = np.array( [4**length_penalty, 4**length_penalty, 3**length_penalty]) expected_scores = expected_log_probs / length_normalization self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores) def test_repeated_ngram_blocking_constraint_init_state(self): ngram_size = 3 batch_size = 2 constraint = RepeatedNGramBlockingConstraint(ngram_size) state = constraint.init_state(batch_size) assert len(state) == batch_size for beam_states in state: assert len(beam_states) == 1 beam_state = beam_states[0] assert len(beam_state.keys()) == 2 assert len(beam_state["current_prefix"]) == 0 assert len(beam_state["seen_ngrams"]) == 0 def test_repeated_ngram_blocking_constraint_apply(self): ngram_size = 3 batch_size = 2 beam_size = 2 num_classes = 10 constraint = RepeatedNGramBlockingConstraint(ngram_size) state = [ [ { "current_prefix": [0, 1], "seen_ngrams": {} }, { "current_prefix": [2, 3], "seen_ngrams": { (2, 3): [4] } }, ], [ { "current_prefix": [4, 5], "seen_ngrams": { (8, 9): [] } }, { "current_prefix": [6, 7], "seen_ngrams": { (6, 7): [0, 1, 2] } }, ], ] log_probabilities = torch.rand(batch_size, beam_size, num_classes) constraint.apply(state, log_probabilities) disallowed_locations = torch.nonzero( log_probabilities == min_value_of_dtype( 
log_probabilities.dtype)).tolist() assert len(disallowed_locations) == 4 assert [0, 1, 4] in disallowed_locations assert [1, 1, 0] in disallowed_locations assert [1, 1, 1] in disallowed_locations assert [1, 1, 2] in disallowed_locations def test_repeated_ngram_blocking_constraint_update_state(self): ngram_size = 3 constraint = RepeatedNGramBlockingConstraint(ngram_size) # We will have [2, 3] -> {5, 6} from batch index 0 and [4, 5] -> {0} and [6, 7] -> {3} # from batch index state = [ [ { "current_prefix": [0, 1], "seen_ngrams": {} }, { "current_prefix": [2, 3], "seen_ngrams": { (2, 3): [4] } }, ], [ { "current_prefix": [4, 5], "seen_ngrams": { (8, 9): [] } }, { "current_prefix": [6, 7], "seen_ngrams": { (6, 7): [0, 1, 2] } }, ], ] predictions = torch.LongTensor([[5, 6], [0, 3]]) backpointers = torch.LongTensor([[1, 1], [0, 1]]) expected_state = [ [ { "current_prefix": [3, 5], "seen_ngrams": { (2, 3): [4, 5] } }, { "current_prefix": [3, 6], "seen_ngrams": { (2, 3): [4, 6] } }, ], [ { "current_prefix": [5, 0], "seen_ngrams": { (8, 9): [], (4, 5): [0] } }, { "current_prefix": [7, 3], "seen_ngrams": { (6, 7): [0, 1, 2, 3] } }, ], ] updated_state = constraint.update_state(state, predictions, backpointers) assert updated_state == expected_state def test_take_repeated_ngram_step(self): """ Tests to ensure the top-k from the `repeated_ngram_transition_probabilities_0` transition matrix is expected. The transitions are: - p(1|start) = 1.0 - p(2|1) = 0.4 - p(3|1) = 0.6 - p(end|1) = 1e-9 - p(3|2) = 1.0 - p(end|2) = 1e-9 - p(1|3) = 1.0 - p(end|3) = 1e-9 The probabilities don't add up 1 because of the 1e-9 transitions to end. That doesn't really matter. Each state just needed some transition to the end probability with a very small probability to ensure it's possible to reach the end state from there and that it isn't selected by beam search without a constraint. Below is the beam search tracing for beam size 2. Any sequence below the line is not selected by beam search. 
The number that comes before the sequence is the probability of the sequence. Step 1 1.0: [1] Step 2 0.6: [1, 3] 0.4: [1, 2] ----- 1e-9: [1, 2, end] Step 3 0.6: [1, 3, 1] 0.4: [1, 2, 3] ----- 0.6 * 1e-9: [1, 3, end] 0.4 * 1e-9: [1, 2, end] Step 4 0.4: [1, 2, 3, 1] 0.36: [1, 3, 1, 3] ----- 0.24: [1, 3, 1, 2] 0.6 * 1e-9: [1, 3, 1, end] 0.4 * 1e-9: [1, 2, 3, end] Step 5 0.36: [1, 3, 1, 3, 1] 0.24: [1, 2, 3, 1, 3] ----- 0.16: [1, 2, 3, 1, 2] 0.4 * 1e-9: [1, 2, 3, 1, end] 0.36 * 1e-9: [1, 3, 1, 3, end] """ step_function = get_step_function( repeated_ngram_transition_probabilities_0) self.beam_search.beam_size = 2 self.beam_search.max_steps = 5 expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]]) expected_log_probs = np.log(np.array([0.36, 0.24])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=step_function, ) def test_repeated_ngram_blocking_end_to_end_unigrams(self): step_function = get_step_function( repeated_ngram_transition_probabilities_0) self.beam_search.beam_size = 2 # Unigrams: On step 3, [1, 3, 1] will be blocked and [1, 3, end] will take its place self.beam_search.max_steps = 3 self.beam_search.constraints = [ RepeatedNGramBlockingConstraint(ngram_size=1) ] expected_top_k = np.array([[1, 2, 3], [1, 3, 5]]) expected_log_probs = np.log(np.array([0.4, 0.6 * 1e-9])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=step_function, ) step_function = get_step_function( repeated_ngram_transition_probabilities_1) self.beam_search.max_steps = 5 expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]]) expected_log_probs = np.log( np.array( [0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=step_function, ) def test_repeated_ngram_blocking_end_to_end_bigrams(self): step_function = get_step_function( repeated_ngram_transition_probabilities_0) 
self.beam_search.beam_size = 2 # Bigrams: On step 4, [1, 3, 1, 3] will be blocked and [1, 3, 1, 2] will take its place self.beam_search.max_steps = 4 self.beam_search.constraints = [ RepeatedNGramBlockingConstraint(ngram_size=2) ] expected_top_k = np.array([[1, 2, 3, 1], [1, 3, 1, 2]]) expected_log_probs = np.log(np.array([0.4, 0.24])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=step_function, ) def test_repeated_ngram_blocking_end_to_end_trigrams(self): step_function = get_step_function( repeated_ngram_transition_probabilities_0) self.beam_search.beam_size = 2 # Trigrams: On step 5, [1, 3, 1, 3, 1] will be blocked and [1, 2, 3, 1, 2] will take its place self.beam_search.max_steps = 5 self.beam_search.constraints = [ RepeatedNGramBlockingConstraint(ngram_size=3) ] expected_top_k = np.array([[1, 2, 3, 1, 3], [1, 2, 3, 1, 2]]) expected_log_probs = np.log(np.array([0.24, 0.16])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=step_function, ) def test_repeated_ngram_blocking_end_indices(self): """ Ensures that the ngram blocking does not mess up when one sequence is shorter than another, which would result in repeated "end" symbols. """ # We block unigrams, but 5 (the end symbol) is repeated and it does not mess # up the sequence's probability step_function = get_step_function( repeated_ngram_transition_probabilities_0) self.beam_search.beam_size = 2 self.beam_search.constraints = [ RepeatedNGramBlockingConstraint(ngram_size=1) ] expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]]) expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9])) self._check_results( expected_top_k=expected_top_k, expected_log_probs=expected_log_probs, take_step=step_function, )
class AssociativeSeq2SeqHiddenDiff(Model):
    """
    Seq2seq model that decodes an instance target from a reference pair plus a "difference" vector.

    The reference target is encoded with ``target_encoder`` and its final state initializes an
    LSTM-cell decoder. The reference source and the instance source are both encoded with
    ``source_encoder``; their element-wise difference (``reference - instance``) is stored in the
    state as ``instance_ref_diff``. At each decoding step the decoder input is the embedding of
    the previous token concatenated with

    - the attended reference-target encoder outputs, when ``attention`` is given, or
    - ``instance_ref_diff`` otherwise.

    NOTE(review): when ``attention`` is given, ``instance_ref_diff`` is computed but never fed to
    the decoder, so the instance source has no influence on the output. Confirm this is intended.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Vocabulary providing the start / end / padding indices of ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``
        Embeds the reference and instance source tokens.
    target_embedder : ``TextFieldEmbedder``
        Embeds the reference target tokens. Its ``'tokens'`` token embedder is also used to embed
        the previously predicted token at every decoding step.
    source_encoder : ``Seq2VecEncoder``
        Encodes each embedded source sequence.
    target_encoder : ``Seq2SeqEncoder``
        Encodes the embedded reference target; its output dim is also the decoder's hidden size.
    max_decoding_steps : ``int``
        Maximum number of steps for beam search and for greedy decoding without gold targets.
    attention : ``Attention``, optional (default = None)
        If given, attention over the reference-target encoder outputs is used at every step.
    beam_size : ``int``, optional (default = None)
        Beam width; falsy values fall back to 1.
    target_namespace : ``str``, optional (default = "tokens")
        Vocabulary namespace of the target tokens.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        During training, probability of feeding the model's own previous prediction instead of
        the gold token.
    use_bleu : ``bool``, optional (default = True)
        Whether to track BLEU outside of training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 target_embedder: TextFieldEmbedder,
                 source_encoder: Seq2VecEncoder,
                 target_encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True) -> None:
        super(AssociativeSeq2SeqHiddenDiff, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,  # pylint: disable=protected-access
                                                   self._target_namespace)
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # The source encoder is applied to both the reference and the instance source; the target
        # encoder turns the reference target into a sequence of hidden states.
        self._source_encoder = source_encoder
        self._target_encoder = target_encoder

        self._encoder_output_dim = self._target_encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        # The decoder embeds previous predictions with the target embedder's 'tokens' token
        # embedder (see `_prepare_output_projections`), so the decoder input width must come from
        # that embedder. Reading it from `source_embedder` was only correct when both embedders
        # happened to share the same output dimension.
        target_embedding_dim = target_embedder._token_embedders['tokens'].get_output_dim()  # pylint: disable=protected-access

        if attention:
            self._attention = attention
            # Attended encoder outputs concatenated with the previous token embedding.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            self._attention = None
            # Previous token embedding concatenated with the (reference - instance) source vector.
            self._decoder_input_dim = target_embedding_dim + self._source_encoder.get_output_dim()

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._target_embedder = target_embedder
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
        We treat the inputs as a batch, even though ``group_size`` is not necessarily
        equal to ``batch_size``, since the group may contain multiple states
        for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                ref_source_tokens: Dict[str, torch.LongTensor],
                instance_source_tokens: Dict[str, torch.LongTensor],
                ref_target_tokens: Dict[str, torch.LongTensor],
                instance_target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Encode the reference pair and the instance source, then decode the instance target.

        Parameters
        ----------
        ref_source_tokens : ``Dict[str, torch.LongTensor]``
            Source tokens of the reference example.
        instance_source_tokens : ``Dict[str, torch.LongTensor]``
            Source tokens of the instance to decode for.
        ref_target_tokens : ``Dict[str, torch.LongTensor]``
            Target tokens of the reference example.
        instance_target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Gold target tokens of the instance. When given, the loss is computed.

        Returns
        -------
        Dict[str, torch.Tensor]
            Contains ``"loss"`` when gold targets are given. Outside of training, ``"predictions"``
            and ``"class_log_probabilities"`` come from beam search.
        """
        state = self._encode(ref_target_tokens, ref_source_tokens, instance_source_tokens)

        if instance_target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, instance_target_tokens)
        else:
            output_dict = {}

        if not self.training:
            # Re-initialize the decoder state: `_forward_loop` has advanced it in place.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if instance_target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, instance_target_tokens["tokens"])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self,
                ref_target_tokens: Dict[str, torch.Tensor],
                ref_source_tokens: Dict[str, torch.Tensor],
                instance_source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode the reference target, the reference source and the instance source.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``"source_mask"`` is the mask of the *reference target* and ``"encoder_outputs"`` its
            encoding (these are what the decoder is initialized from and attends over);
            ``"instance_ref_diff"`` is ``encoded_ref_source - encoded_instance_source``.
        """
        embedded_ref_target = self._target_embedder(ref_target_tokens)
        ref_target_mask = util.get_text_field_mask(ref_target_tokens)
        encoded_ref_target = self._target_encoder(embedded_ref_target, ref_target_mask)

        embedded_ref_source = self._source_embedder(ref_source_tokens)
        ref_source_mask = util.get_text_field_mask(ref_source_tokens)
        encoded_ref_source = self._source_encoder(embedded_ref_source, ref_source_mask)

        embedded_instance_source = self._source_embedder(instance_source_tokens)
        instance_source_mask = util.get_text_field_mask(instance_source_tokens)
        encoded_instance_source = self._source_encoder(embedded_instance_source, instance_source_mask)

        # What distinguishes the reference source from the instance source.
        instance_ref_diff = encoded_ref_source - encoded_instance_source

        return {
                "source_mask": ref_target_mask,
                "encoder_outputs": encoded_ref_target,
                "instance_ref_diff": instance_ref_diff
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Add the initial ``"decoder_hidden"`` and ``"decoder_context"`` entries to ``state``.

        ``state`` is modified in place and also returned.
        """
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
                state["encoder_outputs"],
                state["source_mask"],
                self._target_encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
                start_predictions, state, self.take_step)

        output_dict = {
                "class_log_probabilities": log_probabilities,
                "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`. ``state`` is updated in place with the new
        decoder hidden state and context, and is also returned.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # `last_predictions` is a raw index tensor, not a text-field dict, so the 'tokens'
        # token embedder is called directly instead of the full text field embedder.
        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder._token_embedders['tokens'](last_predictions)  # pylint: disable=protected-access

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim + source_encoder_output_dim)
            decoder_input = torch.cat((embedded_input, state['instance_ref_diff']), -1)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input,
                (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(
                decoder_hidden_state, encoder_outputs, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the BLEU metric outside of training; empty when BLEU is disabled or training."""
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
class BeamSearchTest(AllenNlpTestCase):
    """
    Tests for ``BeamSearch`` driven by the module-level step functions and the
    ``transition_probabilities`` matrix.

    The last index of ``transition_probabilities`` is used as the end index. Unless a test
    overrides them, searches use ``beam_size=3`` and ``max_steps=10`` and are checked against
    ``self.expected_top_k`` / ``self.expected_log_probs``.
    """

    def setup_method(self):
        super().setup_method()
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index, max_steps=10, beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5], [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))

    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.ndarray = None,
        expected_log_probs: np.ndarray = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
        take_step=take_step_with_timestep,
    ) -> None:
        """
        Run a search from ``batch_size`` start predictions of index 0 and check the first batch item.

        Any argument left as ``None`` falls back to the defaults created in ``setup_method``
        (or to an empty state). A non-empty ``state`` dict is passed through to the search as-is,
        so callers can inspect it afterwards.
        """
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (
            expected_log_probs if expected_log_probs is not None else self.expected_log_probs
        )
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state, take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)

    @pytest.mark.parametrize("step_function", [take_step_with_timestep, take_step_no_timestep])
    def test_search(self, step_function):
        self._check_results(take_step=step_function)

    def test_finished_state(self):
        state = {}
        state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]])
        # shape: (batch_size, 3)

        # After the search every batch row appears once per beam (beam_size = 3).
        expected_finished_state = {}
        expected_finished_state["foo"] = np.array(
            [
                [1, 0, 1],
                [1, 0, 1],
                [1, 0, 1],
                [2, 0, 1],
                [2, 0, 1],
                [2, 0, 1],
                [0, 0, 1],
                [0, 0, 1],
                [0, 0, 1],
                [1, 1, 1],
                [1, 1, 1],
                [1, 1, 1],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
            ]
        )
        # shape: (batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_diff_shape_state(self):
        state = {}
        state["decoder_hidden"] = torch.tensor(
            [[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]
        )
        state["decoder_hidden"] = state["decoder_hidden"].unsqueeze(0).repeat(2, 1, 1)
        # shape: (2, batch_size, 3)

        seq = [
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ]
        seq = [seq] * 2
        expected_finished_state = {}
        expected_finished_state["decoder_hidden"] = np.array(seq)
        # shape: (2, batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_batch_size_of_one(self):
        self._check_results(batch_size=1)

    def test_greedy_search(self):
        beam_search = BeamSearch(self.end_index, beam_size=1)
        expected_top_k = np.array([[1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_early_stopping(self):
        """
        Checks case where beam search will reach `max_steps` before finding end tokens.
        """
        beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
        expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_different_per_node_beam_size(self):
        # per_node_beam_size = 1
        beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1)
        self._check_results(beam_search=beam_search)

        # per_node_beam_size = 2
        beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=2)
        self._check_results(beam_search=beam_search)

    def test_catch_bad_config(self):
        """
        If `per_node_beam_size` (which defaults to `beam_size`) is larger than
        the size of the target vocabulary, `BeamSearch.search` should raise
        a ConfigurationError.
        """
        beam_search = BeamSearch(self.end_index, beam_size=20)
        with pytest.raises(ConfigurationError):
            self._check_results(beam_search=beam_search)

    def test_warn_for_bad_log_probs(self):
        # The only valid next step from the initial predictions is the end index.
        # But with a beam size of 3, the call to `topk` to find the 3 most likely
        # next beams will result in 2 new beams that are invalid, in that have probability of 0.
        # The beam search should warn us of this.
        initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
        with pytest.warns(RuntimeWarning, match="Infinite log probabilities"):
            self.beam_search.search(initial_predictions, {}, take_step_no_timestep)

    def test_empty_sequences(self):
        initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
        beam_search = BeamSearch(self.end_index, beam_size=1)
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = beam_search.search(
                initial_predictions, {}, take_step_with_timestep
            )
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    def test_top_p_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep

        top_p, log_probs = BeamSearch.top_p_sampling(self.end_index, beam_size=beam_size).search(
            initial_predictions, {}, take_step
        )

        batch_size = 5

        # top_p should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_p.size())[:-1] == [batch_size, beam_size]
        # Sampling is random, so only check that every token is a valid class index.
        assert ((0 <= top_p) & (top_p <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    def test_top_k_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep

        top_k, log_probs = BeamSearch.top_k_sampling(
            self.end_index, k=1, beam_size=beam_size
        ).search(initial_predictions, {}, take_step)

        batch_size = 5

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        # Only check that every token is a valid class index.
        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    def test_empty_p(self):
        initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
        take_step = take_step_with_timestep

        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = BeamSearch.top_p_sampling(
                self.end_index, beam_size=1
            ).search(initial_predictions, {}, take_step)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    def test_empty_k(self):
        initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
        take_step = take_step_with_timestep

        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = BeamSearch.top_k_sampling(
                self.end_index, beam_size=1
            ).search(initial_predictions, {}, take_step)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    @pytest.mark.parametrize(
        "k",
        [-1.0, 1.2, 1.1, "foo", float("inf")],
    )
    def test_k_val(self, k):
        initial_predictions = torch.tensor([0] * 5)
        take_step = take_step_with_timestep
        beam_size = 3
        # Construction and search both stay inside the block: either may reject the bad `k`.
        with pytest.raises(ConfigurationError):
            BeamSearch.top_k_sampling(self.end_index, k=k, beam_size=beam_size).search(
                initial_predictions, {}, take_step
            )

    @pytest.mark.parametrize(
        "p",
        [-1.0, 1.1, 2, "foo", float("inf")],
    )
    def test_p_val(self, p):
        initial_predictions = torch.tensor([0] * 5)
        take_step = take_step_with_timestep
        beam_size = 3
        # Construction and search both stay inside the block: either may reject the bad `p`.
        with pytest.raises(ConfigurationError):
            BeamSearch.top_p_sampling(self.end_index, p=p, beam_size=beam_size).search(
                initial_predictions, {}, take_step
            )

    def test_params_no_sampling(self):
        beam_search = BeamSearch.from_params(Params({"beam_size": 2, "end_index": 7}))

        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is None

    def test_params_k_sampling(self):
        beam_search = BeamSearch.from_params(
            Params(
                {
                    "type": "top_k_sampling",
                    "beam_size": 2,
                    "end_index": 7,
                    "k": 5,
                }
            )
        )

        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None

    def test_params_p_sampling(self):
        beam_search = BeamSearch.from_params(
            Params(
                {
                    "type": "top_p_sampling",
                    "beam_size": 2,
                    "end_index": 7,
                    "p": 0.4,
                }
            )
        )

        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None
class ImageCaptioning(Model):
    """
    Image-to-sequence captioning model with an attention decoder.

    An image is encoded by a ResNet-18 trunk (its last two child modules removed)
    followed by adaptive average pooling to an ``encoder_size x encoder_size`` grid.
    The flattened grid is decoded one token at a time by ``ImageCaptioningDecoder``.
    When labels are given, per-timestep logits are used to compute a loss; outside
    of training mode captions are produced with beam search.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary providing the start, end and padding token indices.
    max_timesteps : ``int``, optional (default = 50)
        Maximum number of steps taken by beam search.
    encoder_size : ``int``, optional (default = 14)
        Side length of the pooled feature grid.
    encoder_dim : ``int``, optional (default = 512)
        Feature size of every grid location (input size of the layers that
        initialize the decoder state).
    embedding_dim : ``int``, optional (default = 64)
        Passed through to ``ImageCaptioningDecoder``.
    attention_dim : ``int``, optional (default = 64)
        Passed through to ``ImageCaptioningDecoder``.
    decoder_dim : ``int``, optional (default = 64)
        Size of the decoder's ``h`` / ``c`` state.
    beam_size : ``int``, optional (default = 3)
        Beam width used at prediction time.
    teacher_forcing : ``bool``, optional (default = True)
        If true, the gold token of the next timestep is fed to the decoder during
        ``_decode``; otherwise the argmax of the previous prediction is fed back.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 max_timesteps: int = 50,
                 encoder_size: int = 14,
                 encoder_dim: int = 512,
                 embedding_dim: int = 64,
                 attention_dim: int = 64,
                 decoder_dim: int = 64,
                 beam_size: int = 3,
                 teacher_forcing: bool = True) -> None:
        super().__init__(vocab)

        self._max_timesteps = max_timesteps

        self._vocab_size = self.vocab.get_vocab_size()
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        # POSSIBLE CHANGE LATER
        self._pad_index = self.vocab.get_token_index('@@PADDING@@')

        self._encoder_size = encoder_size
        self._encoder_dim = encoder_dim
        self._embedding_dim = embedding_dim
        self._attention_dim = attention_dim
        self._decoder_dim = decoder_dim

        self._beam_size = beam_size
        self._teacher_forcing = teacher_forcing

        # Projections from the mean encoder feature to the initial decoder state.
        self._init_h = nn.Linear(self._encoder_dim, self._decoder_dim)
        self._init_c = nn.Linear(self._encoder_dim, self._decoder_dim)

        # ResNet-18 without its last two child modules, pooled to a fixed grid.
        self._resnet = torchvision.models.resnet18()
        modules = list(self._resnet.children())[:-2]
        self._encoder = nn.Sequential(
            *modules,
            nn.AdaptiveAvgPool2d(self._encoder_size)
        )

        self._decoder = ImageCaptioningDecoder(self._vocab_size, self._encoder_dim, self._embedding_dim,
                                               self._attention_dim, self._decoder_dim)

        self.beam_search = BeamSearch(self._end_index, self._max_timesteps, self._beam_size)

        self._bleu = BLEU(exclude_indices={self._start_index, self._end_index, self._pad_index})
        self._exprate = Exprate(self._end_index)

    def _init_hidden(self, encoder: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the initial decoder state from the mean encoder feature.

        Parameters
        ----------
        encoder : ``torch.Tensor``
            Shape: `(batch_size, height * width, encoder_dim)`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            `(initial_h, initial_c)`, each of shape `(batch_size, decoder_dim)`.
        """
        # Shape: (batch_size, encoder_dim)
        mean_encoder = encoder.mean(dim=1)

        # Shape: (batch_size, decoder_dim)
        initial_h = self._init_h(mean_encoder)
        # Shape: (batch_size, decoder_dim)
        initial_c = self._init_c(mean_encoder)

        return initial_h, initial_c

    def _decode(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run the decoder over all label timesteps and collect logits.

        The batch is re-ordered by decreasing sequence length so that, at every
        timestep, only the leading rows that are still active are decoded.
        NOTE: the re-ordered `x`, `label` and `mask` are written back into `state`,
        so everything downstream sees the batch in that sorted order. `h` and `c`
        are written back as returned by the last decoder call, i.e. they may cover
        only the rows that were still active at the final timestep.

        Parameters
        ----------
        state : ``Dict[str, torch.Tensor]``
            Must contain `x`, `h`, `c`, `label` and `mask`.

        Returns
        -------
        Dict[str, torch.Tensor]
            The same dict, with `logits` of shape `(batch_size, timesteps, vocab_size)`
            and `attention_weights` of shape
            `(batch_size, timesteps, encoder_size * encoder_size)` added.
        """
        x = state['x']
        h = state['h']
        c = state['c']
        label = state['label']
        mask = state['mask']

        # Get actual size of current batch
        local_batch_size = x.shape[0]

        # Sort data to be able to only compute relevant parts of the batch at each timestep
        # Shape: (batch_size)
        lengths = mask.sum(dim=1)
        # Shape: (batch_size) (batch_size)
        sorted_lengths, indices = lengths.sort(dim=0, descending=True)
        # Plain Python numbers: computed once instead of on every timestep.
        sorted_lengths_list = sorted_lengths.tolist()

        # Computing last timestep isn't necessary with labels since last timestep is eos token or pad token
        timesteps = int(sorted_lengths_list[0]) - 1

        x = x[indices]
        h = h[indices]
        c = c[indices]
        label = label[indices]
        mask = mask[indices]

        # Create new tensors on the device of the encoder output rather than on a
        # module-level global, so the model works wherever its inputs live.
        # Shape: (batch_size, 1)
        predicted_indices = x.new_full((local_batch_size, 1), self._start_index, dtype=torch.long)

        grid_cells = self._encoder_size * self._encoder_size
        # Shape: (batch_size, timesteps, vocab_size)
        predictions = torch.zeros(local_batch_size, timesteps, self._vocab_size, device=x.device)
        # Shape: (batch_size, timesteps, encoder_size * encoder_size)
        attention_weights = torch.zeros(local_batch_size, timesteps, grid_cells, device=x.device)

        for t in range(timesteps):
            # Number of leading (longest) rows that are still active at this timestep.
            batch_offset = sum(length > t for length in sorted_lengths_list)

            # Only decode the active rows. `h`/`c`/`predicted_indices` shrink along
            # with `batch_offset`, which never grows because lengths are sorted.
            h, c, preds, attention_weight = self._decoder(x[:batch_offset],
                                                          h[:batch_offset],
                                                          c[:batch_offset],
                                                          predicted_indices[:batch_offset])

            if self._teacher_forcing:
                # Feed the gold token of the next timestep.
                # Shape: (batch_offset, 1)
                predicted_indices = label[:batch_offset, t + 1].view(-1, 1)
            else:
                # Feed back the model's own best guess.
                # Shape: (batch_offset, 1)
                predicted_indices = torch.argmax(preds, dim=1).view(-1, 1)

            predictions[:batch_offset, t, :] = preds
            attention_weights[:batch_offset, t, :] = attention_weight.view(-1, grid_cells)

        # Update state and add logits
        state['x'] = x
        state['h'] = h
        state['c'] = c
        state['label'] = label
        state['mask'] = mask
        state['attention_weights'] = attention_weights
        state['logits'] = predictions

        return state

    def _beam_search_step(self,
                          last_predictions: torch.Tensor,
                          state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Step function handed to ``BeamSearch.search``.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            Shape: `(group_size,)`; group_size is batch_size * beam_size except for
            the first decoding timestep where it is batch_size.
        state : ``Dict[str, torch.Tensor]``
            Must contain `x`, `h` and `c`; `h` and `c` are updated in place.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            Log-probabilities of shape `(group_size, vocab_size)` and the updated state.
        """
        # Shape: (group_size, decoder_dim) (group_size, decoder_dim) (group_size, vocab_size)
        h, c, predictions, _ = self._decoder(state['x'], state['h'], state['c'], last_predictions)

        state['h'] = h
        state['c'] = c

        # Shape: (group_size, vocab_size)
        log_preds = F.log_softmax(predictions, dim=1)

        return log_preds, state

    def _beam_search(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Decode with beam search and keep the best sequence of every batch entry.

        Parameters
        ----------
        state : ``Dict[str, torch.Tensor]``
            Must contain `x`, `h` and `c`.

        Returns
        -------
        torch.Tensor
            Shape: `(batch_size, timesteps)`; timesteps returned aren't necessarily
            `max_timesteps`.
        """
        x = state['x']
        local_batch_size = x.shape[0]

        # Beam search wants initial preds of shape: (batch_size)
        initial_indices = x.new_full((local_batch_size,), self._start_index, dtype=torch.long)

        # Fresh dict so that the search only carries what the step function needs.
        search_state = {'x': x, 'h': state['h'], 'c': state['c']}

        # Shape: (batch_size, beam_size, timesteps), (batch_size, beam_size)
        predictions, _ = self.beam_search.search(initial_indices, search_state, self._beam_search_step)

        # Only keep best predictions from beam search
        # Shape: (batch_size, timesteps)
        return predictions[:, 0, :].view(local_batch_size, -1)

    @overrides
    def forward(self, img: torch.Tensor, label: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Encode the image, compute the loss when labels are given and, outside of
        training mode, decode captions with beam search.

        Parameters
        ----------
        img : ``torch.Tensor``
            Batch of images fed to the ResNet encoder
            (presumably `(batch_size, 3, height, width)` - confirm against the reader).
        label : ``Dict[str, torch.Tensor]``, optional (default = None)
            Text-field tensors of the gold caption; `label['tokens']` has shape
            `(batch_size, timesteps)`.

        Returns
        -------
        Dict[str, torch.Tensor]
            `out`: the logits in training mode, otherwise the best beam-search
            predictions. `loss` is included whenever `label` was given.
            When `label` is given, rows are ordered by decreasing label length
            (the order produced by `_decode`), not in input order.
        """
        # Shape: (batch_size, encoder_dim, height, width)
        x = self._encoder(img)

        # Flatten the grid into a sequence of per-location feature vectors.
        # The channel axis has to be moved last *before* flattening: a bare
        # `view(batch, -1, channels)` on a channel-first tensor only reinterprets
        # memory and mixes channels with spatial positions.
        # NOTE(review): weights trained with the previous `view`-based layout saw
        # differently arranged features and may need re-training.
        # Shape: (batch_size, height * width, encoder_dim)
        batch_size, channels = x.shape[0], x.shape[1]
        x = x.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)

        state = {'x': x}

        # Compute loss on train and val
        if label is not None:
            # Shape: (batch_size, decoder_dim)
            state['h'], state['c'] = self._init_hidden(x)

            # Shape: (batch_size, timesteps)
            state['mask'] = get_text_field_mask(label).to(x.device)
            # Shape: (batch_size, timesteps)
            state['label'] = label['tokens']

            state = self._decode(state)

            # Loss shouldn't be computed on start token
            state['mask'] = state['mask'][:, 1:].contiguous()
            state['target'] = state['label'][:, 1:].contiguous()

            state['loss'] = sequence_cross_entropy_with_logits(state['logits'], state['target'], state['mask'])
            # Doubly stochastic regularization
            state['loss'] += ((1 - torch.sum(state['attention_weights'], dim=1)) ** 2).mean()

        if not self.training:
            # (Re)initialize h and c from `state['x']`, not from the local `x`:
            # `_decode` may have re-ordered the batch, and the initial state of
            # every row has to belong to the same image that beam search attends to.
            state['h'], state['c'] = self._init_hidden(state['x'])

            state['out'] = self._beam_search(state)

            # Compute validation scores
            if 'label' in state:
                self._bleu(state['out'], state['target'])
                self._exprate(state['out'], state['target'])
        else:
            # Set out to logits while training
            state['out'] = state['logits']

        output_dict = {'out': state['out']}
        if 'loss' in state:
            output_dict['loss'] = state['loss']

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the BLEU and expression-rate metrics; empty while training.

        Parameters
        ----------
        reset : ``bool``, optional (default = False)
            Forwarded to each metric's `get_metric`.
        """
        metrics: Dict[str, float] = {}

        # Metrics are only accumulated outside of training mode.
        if not self.training:
            metrics.update(self._bleu.get_metric(reset))
            metrics.update(self._exprate.get_metric(reset))

        return metrics

    def _trim_predictions(self, predictions: torch.Tensor) -> torch.Tensor:
        """
        Replace everything after the first eos token of each row with padding.

        The tensor is modified in place and also returned. The last position of
        every row is always overwritten with the eos token so that each row is
        guaranteed to contain one.

        Parameters
        ----------
        predictions : ``torch.Tensor``
            Shape: `(batch_size, timesteps)`.
        """
        # Nothing to trim when there are no timesteps (and `row[-1]` would fail).
        if predictions.dim() < 2 or predictions.shape[1] == 0:
            return predictions

        for b in range(predictions.shape[0]):
            # Shape: (timesteps); a view, so writes go straight into `predictions`.
            row = predictions[b]

            # Set last predicted index to eos token in case there are no predicted eos tokens
            row[-1] = self._end_index

            # Position of the first eos token; guaranteed to exist by the line above.
            eos_token_idx = int((row == self._end_index).nonzero()[0])

            # Replace all timesteps after first eos token with pad token
            row[eos_token_idx + 1:] = self._pad_index

        return predictions

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Trim the predictions in `output_dict['out']` to their first eos token.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``
            `out` has shape `(batch_size, timesteps)` and is trimmed in place.
        """
        # Shape: (batch_size, timesteps)
        output_dict['out'] = self._trim_predictions(output_dict['out'])

        return output_dict
class CopyNet(Model): """ This is an implementation of `CopyNet <https://arxiv.org/pdf/1603.06393>`_. CopyNet is a sequence-to-sequence encoder-decoder model with a copying mechanism that can copy tokens from the source sentence into the target sentence instead of generating all target tokens only from the target vocabulary. It is very similar to a typical seq2seq model used in neural machine translation tasks, for example, except that in addition to providing a "generation" score at each timestep for the tokens in the target vocabulary, it also provides a "copy" score for each token that appears in the source sentence. In other words, you can think of CopyNet as a seq2seq model with a dynamic target vocabulary that changes based on the tokens in the source sentence, allowing it to predict tokens that are out-of-vocabulary (OOV) with respect to the actual target vocab. Parameters ---------- vocab : ``Vocabulary``, required Vocabulary containing source and target vocabularies. source_embedder : ``TextFieldEmbedder``, required Embedder for source side sequences encoder : ``Seq2SeqEncoder``, required The encoder of the "encoder/decoder" model attention : ``Attention``, required This is used to get a dynamic summary of encoder outputs at each timestep when producing the "generation" scores for the target vocab. beam_size : ``int``, required Beam width to use for beam search prediction. max_decoding_steps : ``int``, required Maximum sequence length of target predictions. target_embedding_dim : ``int``, optional (default = 30) The size of the embeddings for the target vocabulary. copy_token : ``str``, optional (default = '@COPY@') The token used to indicate that a target token was copied from the source. If this token is not already in your target vocabulary, it will be added. source_namespace : ``str``, optional (default = 'source_tokens') The namespace for the source vocabulary. 
target_namespace : ``str``, optional (default = 'target_tokens') The namespace for the target vocabulary. metric : ``Metric``, optional (default = BLEU) A metrics to track on a validation set. Note that this metric must accept three arguments when called: a batched tensor of predicted token indices, a batched tensor of gold token indices, and a set of token indices to exclude when calculating n-grams (usually should be the start index, end index, and pad index). """ def __init__(self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, attention: Attention, beam_size: int, max_decoding_steps: int, target_embedding_dim: int = 30, copy_token: str = "@COPY@", source_namespace: str = "source_tokens", target_namespace: str = "target_tokens", metric: Metric = BLEU()) -> None: super(CopyNet, self).__init__(vocab) self._metric = metric self._source_namespace = source_namespace self._target_namespace = target_namespace self._src_start_index = self.vocab.get_token_index(START_SYMBOL, self._source_namespace) self._src_end_index = self.vocab.get_token_index(END_SYMBOL, self._source_namespace) self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) self._oov_index = self.vocab.get_token_index(self.vocab._oov_token, self._target_namespace) # pylint: disable=protected-access self._pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._copy_index = self.vocab.get_token_index(copy_token, self._target_namespace) if self._copy_index == self._oov_index: raise ConfigurationError(f"Special copy token {copy_token} missing from target vocab namespace. " f"You can ensure this token is added to the target namespace with the " f"vocabulary parameter 'tokens_to_add'.") self._target_vocab_size = self.vocab.get_vocab_size(self._target_namespace) # Encoding modules. 
self._source_embedder = source_embedder self._encoder = encoder # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. # We arbitrarily set the decoder's input dimension to be the same as the output dimension. self.encoder_output_dim = self._encoder.get_output_dim() self.decoder_output_dim = self.encoder_output_dim self.decoder_input_dim = self.decoder_output_dim target_vocab_size = self.vocab.get_vocab_size(self._target_namespace) # The decoder input will be a function of the embedding of the previous predicted token, # an attended encoder hidden state called the "attentive read", and another # weighted sum of the encoder hidden state called the "selective read". # While the weights for the attentive read are calculated by an `Attention` module, # the weights for the selective read are simply the predicted probabilities # corresponding to each token in the source sentence from the previous timestep. self._target_embedder = Embedding(target_vocab_size, target_embedding_dim) self._attention = attention self._input_projection_layer = Linear( target_embedding_dim + self.encoder_output_dim * 2, self.decoder_input_dim) # We then run the projected decoder input through an LSTM cell to produce # the next hidden state. self._decoder_cell = LSTMCell(self.decoder_input_dim, self.decoder_output_dim) # We create a "generation" score for each token in the target vocab # with a linear projection of the decoder hidden state. self._output_generation_layer = Linear(self.decoder_output_dim, target_vocab_size) # We create a "copying" score for each source token by applying a non-linearity # (tanh) to a linear projection of the encoded hidden state for that token, # and then taking the dot product of the result with the decoder hidden state. 
self._output_copying_layer = Linear(self.encoder_output_dim, self.decoder_output_dim) # At prediction time, we'll use a beam search to find the best target sequence. self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) @overrides def forward(self, # type: ignore source_tokens: Dict[str, torch.LongTensor], source_to_source: torch.Tensor, source_to_target: torch.Tensor, metadata: List[Dict[str, Any]], target_tokens: Dict[str, torch.LongTensor] = None, target_to_source: torch.Tensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Make foward pass with decoder logic for producing the entire target sequence. Parameters ---------- source_tokens : ``Dict[str, torch.LongTensor]``, required The output of `TextField.as_array()` applied on the source `TextField`. This will be passed through a `TextFieldEmbedder` and then through an encoder. source_to_source : ``torch.Tensor``, required Tensor containing indicators of which source tokens match each other. Has shape: `(batch_size, trimmed_source_length, trimmed_source_length)`. source_to_target : ``torch.Tensor``, required Tensor containing vocab index of each source token with respect to the target vocab namespace. Shape: `(batch_size, trimmed_source_length)`. target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) Output of `Textfield.as_array()` applied on target `TextField`. We assume that the target tokens are also represented as a `TextField`. target_to_source : ``torch.Tensor``, optional (default = None) A sparse tensor of shape `(batch_size, target_sequence_length, source_sentence_length - 2)` that indicates which tokens in the source sentence match each token in the target sequence. The last dimension is `source_sentence_length - 2` because we exclude the START_SYMBOL and END_SYMBOL in the source sentence (the source sentence is guaranteed to contain the START_SYMBOL and END_SYMBOL). 
Returns ------- Dict[str, torch.Tensor] """ state = self._encode(source_tokens, source_to_source, source_to_target) if target_tokens: state = self._init_decoder_state(state) output_dict = self._forward_loop(target_tokens, target_to_source, state) else: output_dict = {} output_dict["metadata"] = metadata if not self.training: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if self._metric and target_tokens: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] # shape: (batch_size, target_sequence_length) gold_tokens = self._gather_extended_gold_tokens(target_tokens["tokens"], target_to_source) self._metric(best_predictions, gold_tokens, (self._pad_index, self._end_index, self._start_index)) return output_dict def _gather_extended_gold_tokens(self, target_tokens: torch.LongTensor, target_to_source: torch.Tensor) -> torch.LongTensor: """ Modify the gold target tokens relative to the extended vocabulary. For gold targets that are OOV but were copied from the source, the OOV index will be changed to the index of the first occurence in the source sentence, offset by the size of the target vocabulary. Parameters ---------- target_tokens : ``torch.LongTensor`` Shape: `(batch_size, target_sequence_length)`. target_to_source : ``torch.Tensor`` Shape: `(batch_size, target_sequence_length, trimmed_source_length)`. Returns ------- torch.Tensor Modified `target_tokens` with OOV indices replaced by offset index of first match in source sentence. """ # Only change indices for tokens that were OOV in target vocab but copied from source. 
# shape: (batch_size, target_sequence_length) oov = (target_tokens == self._oov_index) # shape: (batch_size, target_sequence_length) copied = (target_to_source.sum(-1) > 0) # shape: (batch_size, target_sequence_length) mask = (oov & copied).long() # shape: (batch_size, target_sequence_length) _, first_match = target_to_source.max(-1) # shape: (batch_size, target_sequence_length) new_target_tokens = target_tokens * (1 - mask) + (first_match.long() + self._target_vocab_size) * mask return new_target_tokens def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Initialize the encoded state to be passed to the first decoding time step. """ batch_size, _ = state["source_mask"].size() # Initialize the decoder hidden state with the final output of the encoder, # and the decoder context with zeros. # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()) # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self.decoder_output_dim) return state def _encode(self, source_tokens: Dict[str, torch.Tensor], source_to_source: torch.Tensor, source_to_target: torch.Tensor) -> Dict[str, torch.Tensor]: """ Encode source input sentences. 
""" # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder(embedded_input, source_mask) state = { "source_mask": source_mask, "encoder_outputs": encoder_outputs, "source_to_source": source_to_source, "source_to_target": source_to_target, } return state def _decoder_step(self, last_predictions: torch.Tensor, selective_weights: torch.Tensor, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs_mask = state["source_mask"].float() # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) # shape: (batch_size, max_input_sequence_length) attentive_weights = self._attention( state["decoder_hidden"], state["encoder_outputs"], encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attentive_read = util.weighted_sum(state["encoder_outputs"], attentive_weights) # shape: (batch_size, encoder_output_dim) selective_read = util.weighted_sum(state["encoder_outputs"][:, 1:-1], selective_weights) # shape: (group_size, target_embedding_dim + encoder_output_dim * 2) decoder_input = torch.cat((embedded_input, attentive_read, selective_read), -1) # shape: (group_size, decoder_input_dim) projected_decoder_input = self._input_projection_layer(decoder_input) state["decoder_hidden"], state["decoder_context"] = self._decoder_cell( projected_decoder_input, (state["decoder_hidden"], state["decoder_context"])) return state def _get_generation_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor: return self._output_generation_layer(state["decoder_hidden"]) def _get_copy_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor: # shape: (batch_size, max_input_sequence_length - 
2, encoder_output_dim) trimmed_encoder_outputs = state["encoder_outputs"][:, 1:-1] # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim) copy_projection = self._output_copying_layer(trimmed_encoder_outputs) # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim) copy_projection = torch.tanh(copy_projection) # shape: (batch_size, max_input_sequence_length - 2) copy_scores = copy_projection.bmm(state["decoder_hidden"].unsqueeze(-1)).squeeze(-1) return copy_scores def _get_ll_contrib(self, generation_scores: torch.Tensor, copy_scores: torch.Tensor, target_tokens: torch.Tensor, target_to_source: torch.Tensor, copy_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Get the log-likelihood contribution from a single timestep. Parameters ---------- generation_scores : ``torch.Tensor`` Shape: `(batch_size, target_vocab_size)` copy_scores : ``torch.Tensor`` Shape: `(batch_size, trimmed_source_length)` target_tokens : ``torch.Tensor`` Shape: `(batch_size,)` target_to_source : ``torch.Tensor`` Shape: `(batch_size, trimmed_source_length)` copy_mask : ``torch.Tensor`` Shape: `(batch_size, trimmed_source_length)` Returns ------- Tuple[torch.Tensor, torch.Tensor] Shape: `(batch_size,), (batch_size, max_input_sequence_length)` """ _, target_size = generation_scores.size() # The point of this mask is to just mask out all source token scores # that just represent padding. We apply the mask to the concatenation # of the generation scores and the copy scores to normalize the scores # correctly during the softmax. # shape: (batch_size, target_vocab_size + trimmed_source_length) mask = torch.cat((generation_scores.new_full(generation_scores.size(), 1.0), copy_mask), dim=-1) # shape: (batch_size, target_vocab_size + trimmed_source_length) all_scores = torch.cat((generation_scores, copy_scores), dim=-1) # Normalize generation and copy scores. 
# shape: (batch_size, target_vocab_size + trimmed_source_length) probs = util.masked_softmax(all_scores, mask) # Calculate the probability (normalized copy score) for each token in the source sentence # that matches the current target token. We end up summing the scores # for each occurence of a matching token to get the actual score, but we also # need the un-summed probabilities to create the selective read state # during the next time step. # shape: (batch_size, trimmed_source_length) raw_selective_weights = probs[:, target_size:] * target_to_source.float() # shape: (batch_size,) sum_selective_weights = raw_selective_weights.sum(-1) # shape: (batch_size, trimmed_source_length) selective_weights = raw_selective_weights / (sum_selective_weights.unsqueeze(-1) + 1e-13) # This mask ensures that item in the batch has a non-zero generation score for this timestep # only when the gold target token is not OOV or there are no matching tokens # in the source sentence. # shape: (batch_size,) gen_mask = ((target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)).float() # Now we get the generation score for the gold target token. # shape: (batch_size,) step_likelihood = probs.gather(1, target_tokens.unsqueeze(1)).squeeze(-1) * gen_mask # ... and add the copy score. # shape: (batch_size,) step_likelihood = step_likelihood + sum_selective_weights # shape: (batch_size,) step_log_likelihood = step_likelihood.log() return step_log_likelihood, selective_weights def _forward_loop(self, target_tokens: Dict[str, torch.LongTensor], target_to_source: torch.Tensor, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Calculate the loss against gold targets. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size, target_sequence_length = target_tokens["tokens"].size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. 
num_decoding_steps = target_sequence_length - 1 # We use this to fill in the copy index when the previous input was copied. # shape: (batch_size,) copy_input_choices = source_mask.new_full((batch_size,), fill_value=self._copy_index) # shape: (batch_size, trimmed_source_length) copy_mask = source_mask[:, 1:-1].float() # We need to keep track of the probabilities assigned to tokens in the source # sentence that were copied during the previous timestep, since we use # those probabilities as weights when calculating the "selective read". # shape: (batch_size, trimmed_source_length) selective_weights = state["decoder_hidden"].new_zeros(copy_mask.size()) step_log_likelihoods = [] for timestep in range(num_decoding_steps): # shape: (batch_size,) input_choices = target_tokens["tokens"][:, timestep] # If the previous target token was copied, we use the special copy token. # But the end target token will always be THE end token, so we know # it was not copied. if timestep < num_decoding_steps - 1: # Get mask tensor indicating which instances were copied. # shape: (batch_size,) copied = (target_to_source[:, timestep, :].sum(-1) > 0).long() # shape: (batch_size,) input_choices = input_choices * (1 - copied) + copy_input_choices * copied # Update the decoder state by taking a step through the RNN. state = self._decoder_step(input_choices, selective_weights, state) # Get generation scores for each token in the target vocab. # shape: (batch_size, target_vocab_size) generation_scores = self._get_generation_scores(state) # Get copy scores for each token in the source sentence, excluding the start # and end tokens. 
# shape: (batch_size, max_input_sequence_length - 2) copy_scores = self._get_copy_scores(state) # shape: (batch_size,) step_target_tokens = target_tokens["tokens"][:, timestep + 1] # shape: (batch_size, max_input_sequence_length - 2) step_target_to_source = target_to_source[:, timestep + 1] step_log_likelihood, selective_weights = self._get_ll_contrib( generation_scores, copy_scores, step_target_tokens, step_target_to_source, copy_mask) step_log_likelihoods.append(step_log_likelihood.unsqueeze(1)) # Gather step log-likelihoods. # shape: (batch_size, num_decoding_steps = target_sequence_length - 1) log_likelihoods = torch.cat(step_log_likelihoods, 1) # Get target mask to exclude likelihood contributions from timesteps after # the END token. # shape: (batch_size, target_sequence_length) target_mask = util.get_text_field_mask(target_tokens) # The first timestep is just the START token, which is not included in the likelihoods. # shape: (batch_size, num_decoding_steps) target_mask = target_mask[:, 1:].float() # Sum of step log-likelihoods. # shape: (batch_size,) log_likelihood = (log_likelihoods * target_mask).sum(dim=-1) # The loss is the negative log-likelihood, averaged over the batch. loss = - log_likelihood.sum() / batch_size return {"loss": loss} def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size, source_length = state["source_mask"].size() trimmed_source_length = source_length - 2 # Initialize the copy scores to zero. 
state["copy_probs"] = state["decoder_hidden"].new_zeros((batch_size, trimmed_source_length)) # shape: (batch_size,) start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_search_step) output_dict = { "predicted_log_probs": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _get_input_and_selective_weights(self, last_predictions: torch.LongTensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.LongTensor, torch.Tensor]: """ Get input choices for the decoder and the selective copy weights. The decoder input choices are simply the `last_predictions`, except for target OOV predictions that were copied from source tokens, in which case the prediction will be changed to the COPY symbol in the target namespace. The selective weights are just the probabilities assigned to source tokens that were copied, normalized to sum to 1. If no source tokens were copied, there will be all zeros. Parameters ---------- last_predictions : ``torch.LongTensor`` Shape: `(group_size,)` state : ``Dict[str, torch.Tensor]`` Returns ------- Tuple[torch.LongTensor, torch.Tensor] `input_choices` (shape `(group_size,)`) and `selective_weights` (shape `(group_size, trimmed_source_length)`). """ group_size, trimmed_source_length = state["source_to_target"].size() # This is a mask indicating which last predictions were copied from the # the source AND not in the target vocabulary (OOV). # (group_size,) only_copied_mask = (last_predictions >= self._target_vocab_size).long() # If the last prediction was in the target vocab or OOV but not copied, # we use that as input, otherwise we use the COPY token. 
# shape: (group_size,) copy_input_choices = only_copied_mask.new_full((group_size,), fill_value=self._copy_index) input_choices = last_predictions * (1 - only_copied_mask) + copy_input_choices * only_copied_mask # In order to get the `selective_weights`, we need to find out which predictions # were copied or copied AND generated, which is the case when a prediction appears # in both the source sentence and the target vocab. But whenever a prediction # is in the target vocab (even if it also appeared in the source sentence), # its index will be the corresponding target vocab index, not its index in # the source sentence offset by the target vocab size. So we first # use `state["source_to_target"]` to get an indicator of every source token # that matches the predicted target token. # shape: (group_size, trimmed_source_length) expanded_last_predictions = last_predictions.unsqueeze(-1).expand(group_size, trimmed_source_length) # shape: (group_size, trimmed_source_length) source_copied_and_generated = (state["source_to_target"] == expanded_last_predictions).long() # In order to get indicators for copied source tokens that are OOV with respect # to the target vocab, we'll make use of `state["source_to_source"]`. # First we adjust predictions relative to the start of the source tokens. # This makes sense because predictions for copied tokens are given by the index of the copied # token in the source sentence, offset by the size of the target vocabulary. # shape: (group_size,) adjusted_predictions = last_predictions - self._target_vocab_size # The adjusted indices for items that were not copied will be negative numbers, # and therefore invalid. So we zero them out. adjusted_predictions = adjusted_predictions * only_copied_mask # shape: (group_size, trimmed_source_length, trimmed_source_length) source_to_source = state["source_to_source"] # Expand adjusted_predictions to match source_to_source shape. 
# shape: (group_size, trimmed_source_length, trimmed_source_length) adjusted_predictions = adjusted_predictions.unsqueeze(-1)\ .unsqueeze(-1)\ .expand(source_to_source.size()) # The mask will contain indicators for source tokens that were copied # during the last timestep. # shape: (group_size, trimmed_source_length) source_only_copied = source_to_source.gather(-1, adjusted_predictions)[:, :, 0].long() # Since we zero'd-out indices for predictions that were not copied, # we need to zero out all entries of this mask corresponding to those predictions. source_only_copied = source_only_copied * only_copied_mask.\ unsqueeze(-1).\ expand(source_only_copied.size()) # shape: (group_size, trimmed_source_length) mask = source_only_copied | source_copied_and_generated # shape: (group_size, trimmed_source_length) raw_selective_weights = state["copy_probs"] * mask.float() # shape: (group_size, trimmed_source_length) selective_weights = raw_selective_weights / (raw_selective_weights.sum(dim=-1, keepdim=True) + 1e-13) return input_choices, selective_weights def _gather_final_probs(self, generation_probs: torch.Tensor, copy_probs: torch.Tensor, state: Dict[str, torch.Tensor]) -> torch.Tensor: """ Combine copy probabilities with generation probabilities for matching tokens. Parameters ---------- generation_probs : ``torch.Tensor`` Shape: `(group_size, target_vocab_size)` copy_probs : ``torch.Tensor`` Shape: `(group_size, trimmed_source_length)` state : ``Dict[str, torch.Tensor]`` Returns ------- torch.Tensor Shape: `(group_size, target_vocab_size + trimmed_source_length)`. 
""" _, trimmed_source_length = state["source_to_target"].size() # shape: [(batch_size, *)] modified_probs_list: List[torch.Tensor] = [generation_probs] for i in range(trimmed_source_length): # shape: (group_size,) copy_probs_slice = copy_probs[:, i] # `source_to_target` is a matrix of shape (group_size, trimmed_source_length) # where element (i, j) is the vocab index of the target token that matches the jth # source token in the ith group, if there is one, or the index of the OOV symbol otherwise. # We'll use this to add copy scores to corresponding generation scores. # shape: (group_size,) source_to_target_slice = state["source_to_target"][:, i] # The OOV index in the source_to_target_slice indicates that the source # token is not in the target vocab, so we don't want to add that copy score # to the OOV token. copy_probs_to_add_mask = (source_to_target_slice != self._oov_index).float() copy_probs_to_add = copy_probs_slice * copy_probs_to_add_mask generation_probs.scatter_add_( -1, source_to_target_slice.unsqueeze(-1), copy_probs_to_add.unsqueeze(-1)) # We have to combine copy scores for duplicate source tokens so that # we can find the overall most likely source token. So, if this is the first # occurence of this particular source token, we add the probs from all other # occurences, otherwise we zero it out since it was already accounted for. if i < (trimmed_source_length - 1): # Sum copy scores from future occurences of source token. # shape: (group_size, trimmed_source_length - i) source_future_occurences = state["source_to_source"][:, i, (i+1):] # shape: (group_size, trimmed_source_length - i) future_copy_probs = copy_probs[:, (i+1):] * source_future_occurences # shape: (group_size,) summed_future_copy_probs = future_copy_probs.sum(dim=-1) copy_probs_slice = copy_probs_slice + summed_future_copy_probs if i > 0: # Zero-out copy probs that we have already accounted for. 
# shape: (group_size, i) source_previous_occurences = state["source_to_source"][:, i, 0:i] # shape: (group_size,) duplicate_mask = (source_previous_occurences.sum(dim=-1) == 0).float() copy_probs_slice = copy_probs_slice * duplicate_mask # Finally, we zero-out copy scores that we added to the generation scores # above so that we don't double-count them. # shape: (group_size,) left_over_copy_probs = copy_probs_slice * (1.0 - copy_probs_to_add_mask) modified_probs_list.append(left_over_copy_probs.unsqueeze(-1)) # shape: (group_size, target_vocab_size + trimmed_source_length) modified_probs = torch.cat(modified_probs_list, dim=-1) return modified_probs def take_search_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take step during beam search. This function is what gets passed to the `BeamSearch.search` method. It takes predictions from the last timestep and the current state and outputs the log probabilities assigned to tokens for the next timestep, as well as the updated state. Since we are predicting tokens out of the extended vocab (target vocab + all unique tokens from the source sentence), this is a little more complicated that just making a forward pass through the model. The output log probs will have shape `(group_size, target_vocab_size + trimmed_source_length)` so that each token in the target vocab and source sentence are assigned a probability. Note that copy scores are assigned to each source token based on their position, not unique value. So if a token appears more than once in the source sentence, it will have more than one score. Further, if a source token is also part of the target vocab, its final score will be the sum of the generation and copy scores. 
Therefore, in order to get the score for all tokens in the extended vocab at this step, we have to combine copy scores for re-occuring source tokens and potentially add them to the generation scores for the matching token in the target vocab, if there is one. So we can break down the final log probs output as the concatenation of two matrices, A: `(group_size, target_vocab_size)`, and B: `(group_size, trimmed_source_length)`. Matrix A contains the sum of the generation score and copy scores (possibly 0) for each target token. Matrix B contains left-over copy scores for source tokens that do NOT appear in the target vocab, with zeros everywhere else. But since a source token may appear more than once in the source sentence, we also have to sum the scores for each appearance of each unique source token. So matrix B actually only has non-zero values at the first occurence of each source token that is not in the target vocab. Parameters ---------- last_predictions : ``torch.Tensor`` Shape: `(group_size,)` state : ``Dict[str, torch.Tensor]`` Contains all state tensors necessary to produce generation and copy scores for next step. Notes ----- `group_size` != `batch_size`. In fact, `group_size` = `batch_size * beam_size`. """ _, trimmed_source_length = state["source_to_target"].size() # Get input to the decoder RNN and the selective weights. `input_choices` # is the result of replacing target OOV tokens in `last_predictions` with the # copy symbol. `selective_weights` consist of the normalized copy probabilities # assigned to the source tokens that were copied. If no tokens were copied, # there will be all zeros. # shape: (group_size,), (group_size, trimmed_source_length) input_choices, selective_weights = self._get_input_and_selective_weights(last_predictions, state) # Update the decoder state by taking a step through the RNN. state = self._decoder_step(input_choices, selective_weights, state) # Get the un-normalized generation scores for each token in the target vocab. 
# shape: (group_size, target_vocab_size) generation_scores = self._get_generation_scores(state) # Get the un-normalized copy scores for each token in the source sentence, # excluding the start and end tokens. # shape: (group_size, trimmed_source_length) copy_scores = self._get_copy_scores(state) # Concat un-normalized generation and copy scores. # shape: (batch_size, target_vocab_size + trimmed_source_length) all_scores = torch.cat((generation_scores, copy_scores), dim=-1) # shape: (group_size, trimmed_source_length) copy_mask = state["source_mask"][:, 1:-1].float() # shape: (batch_size, target_vocab_size + trimmed_source_length) mask = torch.cat((generation_scores.new_full(generation_scores.size(), 1.0), copy_mask), dim=-1) # Normalize generation and copy scores. # shape: (batch_size, target_vocab_size + trimmed_source_length) probs = util.masked_softmax(all_scores, mask) # shape: (group_size, target_vocab_size), (group_size, trimmed_source_length) generation_probs, copy_probs = probs.split([self._target_vocab_size, trimmed_source_length], dim=-1) # Update copy_probs needed for getting the `selective_weights` at the next timestep. state["copy_probs"] = copy_probs # We now have normalized generation and copy scores, but to produce the final # score for each token in the extended vocab, we have to go through and add # the copy scores to the generation scores of matching target tokens, and sum # the copy scores of duplicate source tokens. 
# shape: (group_size, target_vocab_size + trimmed_source_length) final_probs = self._gather_final_probs(generation_probs, copy_probs, state) return final_probs.log(), state def _get_predicted_tokens(self, predicted_indices: numpy.ndarray, batch_metadata: List[Any], n_best: int = None) -> List[Union[List[List[str]], List[str]]]: if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() predicted_tokens: List[Union[List[List[str]], List[str]]] = [] for top_k_predictions, metadata in zip(predicted_indices, batch_metadata): batch_predicted_tokens: List[List[str]] = [] for indices in top_k_predictions[:n_best]: tokens: List[str] = [] indices = list(indices) if self._end_index in indices: indices = indices[:indices.index(self._end_index)] for index in indices: if index >= self._target_vocab_size: adjusted_index = index - self._target_vocab_size token = metadata["source_tokens"][adjusted_index] else: token = self.vocab.get_token_from_index(index, self._target_namespace) tokens.append(token) batch_predicted_tokens.append(tokens) if n_best == 1: predicted_tokens.append(batch_predicted_tokens[0]) else: predicted_tokens.append(batch_predicted_tokens) return predicted_tokens @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: """ Finalize predictions. After a beam search, the predicted indices correspond to tokens in the target vocabulary OR tokens in source sentence. Here we gather the actual tokens corresponding to the indices. """ predicted_tokens = self._get_predicted_tokens(output_dict["predictions"], output_dict["metadata"]) output_dict["predicted_tokens"] = predicted_tokens return output_dict @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._metric and not self.training: all_metrics.update(self._metric.get_metric(reset=reset)) return all_metrics
class BeamSearchTest(AllenNlpTestCase):
    """
    Tests for `BeamSearch`, driven by the module-level `take_step` function and
    `transition_probabilities` tensor (both defined outside this class).
    """

    def setup_method(self):
        super().setup_method()
        # The end symbol is the last index of the transition matrix.
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index, max_steps=10, beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5], [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))

    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.ndarray = None,
        expected_log_probs: np.ndarray = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
    ) -> None:
        """
        Run a search from all-zero initial predictions and compare the first batch
        item's beams and log probabilities against the expected values.

        Any argument left as ``None`` falls back to the fixture created in
        `setup_method` (or to a fresh empty state dict). The `state` dict is handed
        to the search as-is, so callers can inspect it afterwards.
        """
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (expected_log_probs
                              if expected_log_probs is not None else self.expected_log_probs)
        # Explicit `None` checks: a caller-supplied empty dict must be passed through
        # (not swapped for a new one), otherwise the caller could never observe the
        # finished state.
        state = {} if state is None else state
        beam_search = self.beam_search if beam_search is None else beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state, take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)

    def test_search(self):
        self._check_results()

    def test_finished_state(self):
        """
        The state dict passed to the search should end up with each row repeated
        `beam_size` (3) times.
        """
        state = {}
        state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]])
        # shape: (batch_size, 3)

        expected_finished_state = {}
        expected_finished_state["foo"] = np.array([
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ])
        # shape: (batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_batch_size_of_one(self):
        self._check_results(batch_size=1)

    def test_greedy_search(self):
        """A beam size of 1 should return only the single best sequence."""
        beam_search = BeamSearch(self.end_index, beam_size=1)
        expected_top_k = np.array([[1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_early_stopping(self):
        """
        Checks case where beam search will reach `max_steps` before finding end tokens.
        """
        beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
        expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_different_per_node_beam_size(self):
        """Smaller `per_node_beam_size` values should still yield the default expected results."""
        # per_node_beam_size = 1
        beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1)
        self._check_results(beam_search=beam_search)

        # per_node_beam_size = 2
        beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=2)
        self._check_results(beam_search=beam_search)

    def test_catch_bad_config(self):
        """
        If `per_node_beam_size` (which defaults to `beam_size`) is larger than
        the size of the target vocabulary, `BeamSearch.search` should raise
        a ConfigurationError.
        """
        beam_search = BeamSearch(self.end_index, beam_size=20)
        with pytest.raises(ConfigurationError):
            self._check_results(beam_search=beam_search)

    def test_warn_for_bad_log_probs(self):
        # The only valid next step from the initial predictions is the end index.
        # But with a beam size of 3, the call to `topk` to find the 3 most likely
        # next beams will result in 2 new beams that are invalid, in that they have
        # a probability of 0.
        # The beam search should warn us of this.
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        with pytest.warns(RuntimeWarning, match="Infinite log probabilities"):
            self.beam_search.search(initial_predictions, {}, take_step)

    def test_empty_sequences(self):
        """
        Starting one step before the end index with a beam size of 1 should warn
        about empty sequences and return only the end index with log probability 0.
        """
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        beam_search = BeamSearch(self.end_index, beam_size=1)
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = beam_search.search(
                initial_predictions, {}, take_step)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()
class CopyNetSeq2Seq(Model):
    """
    CopyNet-style sequence-to-sequence model: a pretrained BERT source embedder
    feeding a pass-through encoder, and an LSTM-cell decoder that can either
    generate a token from the target vocabulary or copy a token from the source.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        attention: Attention,
        beam_size: int,
        max_decoding_steps: int,
        target_embedding_dim: int = 30,
        copy_token: str = "@COPY@",
        source_namespace: str = "bert",
        target_namespace: str = "target_tokens",
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
        initializer: InitializerApplicator = InitializerApplicator(),
    ) -> None:
        """
        Parameters
        ----------
        vocab : ``Vocabulary``
            Vocabulary holding the source and target namespaces. `copy_token` is
            added to the target namespace here.
        attention : ``Attention``
            Used to compute the weights of the "attentive read" at each decoding step.
        beam_size : ``int``
            Beam width used by the beam search at prediction time.
        max_decoding_steps : ``int``
            Maximum number of steps taken by the beam search.
        target_embedding_dim : ``int``, optional (default = 30)
            Size of the target token embeddings.
        copy_token : ``str``, optional (default = "@COPY@")
            Symbol fed to the decoder when the previous token was copied from the source.
        source_namespace : ``str``, optional (default = "bert")
        target_namespace : ``str``, optional (default = "target_tokens")
        tensor_based_metric : ``Metric``, optional
            Called with predicted and gold index tensors. Falls back to BLEU
            (ignoring pad, end and start indices) when not given.
        token_based_metric : ``Metric``, optional
            Called with predicted and gold string tokens, when given.
        initializer : ``InitializerApplicator``, optional
            Applied to the model once all sub-modules are built.
        """
        super().__init__(vocab)
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._src_start_index = self.vocab.get_token_index(
            START_SYMBOL, self._source_namespace)
        self._src_end_index = self.vocab.get_token_index(
            END_SYMBOL, self._source_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._oov_index = self.vocab.get_token_index(self.vocab._oov_token,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)
        self._copy_index = self.vocab.add_token_to_namespace(
            copy_token, self._target_namespace)

        self._tensor_based_metric = tensor_based_metric or BLEU(
            exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        self._token_based_metric = token_based_metric

        # Read after the copy token has been added, so the size includes it.
        self._target_vocab_size = self.vocab.get_vocab_size(
            self._target_namespace)

        # Encoding modules.
        bert_token_embedding = PretrainedBertEmbedder('bert-base-uncased',
                                                      requires_grad=True)

        self._source_embedder = bert_token_embedding
        self._encoder = PassThroughEncoder(
            input_dim=self._source_embedder.get_output_dim())

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        # We arbitrarily set the decoder's input dimension to be the same as the output dimension.
        self.encoder_output_dim = self._encoder.get_output_dim()
        self.decoder_output_dim = self.encoder_output_dim
        self.decoder_input_dim = self.decoder_output_dim

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

        # The decoder input will be a function of the embedding of the previous predicted token,
        # an attended encoder hidden state called the "attentive read", and another
        # weighted sum of the encoder hidden state called the "selective read".
        # While the weights for the attentive read are calculated by an `Attention` module,
        # the weights for the selective read are simply the predicted probabilities
        # corresponding to each token in the source sentence that matches the target
        # token from the previous timestep.
        self._target_embedder = Embedding(target_vocab_size,
                                          target_embedding_dim)
        self._attention = attention
        self._input_projection_layer = Linear(
            target_embedding_dim + self.encoder_output_dim * 2,
            self.decoder_input_dim)

        # We then run the projected decoder input through an LSTM cell to produce
        # the next hidden state.
        self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                      self.decoder_output_dim)

        # We create a "generation" score for each token in the target vocab
        # with a linear projection of the decoder hidden state.
        self._output_generation_layer = Linear(self.decoder_output_dim,
                                               target_vocab_size)

        # We create a "copying" score for each source token by applying a non-linearity
        # (tanh) to a linear projection of the encoded hidden state for that token,
        # and then taking the dot product of the result with the decoder hidden state.
        self._output_copying_layer = Linear(self.encoder_output_dim,
                                            self.decoder_output_dim)

        # At prediction time, we'll use a beam search to find the best target sequence.
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        initializer(self)

    @overrides
    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        source_token_ids: torch.Tensor,
        source_to_target: torch.Tensor,
        metadata: List[Dict[str, Any]],
        target_tokens: Dict[str, torch.LongTensor] = None,
        target_token_ids: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of `TextField.as_array()` applied on the source `TextField`. This will be
            passed through a `TextFieldEmbedder` and then through an encoder.
        source_token_ids : ``torch.Tensor``, required
            Tensor containing IDs that indicate which source tokens match each other.
            Has shape: `(batch_size, trimmed_source_length)`.
        source_to_target : ``torch.Tensor``, required
            Tensor containing vocab index of each source token with respect to the
            target vocab namespace. Shape: `(batch_size, trimmed_source_length)`.
        metadata : ``List[Dict[str, Any]]``, required
            Metadata field that contains the original source tokens with key 'source_tokens'
            and any other meta fields. When 'target_tokens' is also passed, the metadata
            should also contain the original target tokens with key 'target_tokens'.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
            target tokens are also represented as a `TextField` which must contain a "tokens"
            key that uses single ids.
        target_token_ids : ``torch.Tensor``, optional (default = None)
            A tensor of shape `(batch_size, target_sequence_length)` which indicates which
            tokens in the target sequence match tokens in the source sequence.

        Returns
        -------
        Dict[str, torch.Tensor]
            Always contains ``"metadata"``. Contains ``"loss"`` when `target_tokens` is
            given, and ``"predictions"`` / ``"predicted_log_probs"`` when the model is
            not in training mode.
        """
        state = self._encode(source_tokens)
        state["source_token_ids"] = source_token_ids
        state["source_to_target"] = source_to_target

        if target_tokens:
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, target_token_ids,
                                             state)
        else:
            output_dict = {}

        output_dict["metadata"] = metadata

        if not self.training:
            # Re-initialize: the loss computation above advances the decoder state.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    # shape: (batch_size, target_sequence_length)
                    gold_tokens = self._gather_extended_gold_tokens(
                        target_tokens["tokens"], source_token_ids,
                        target_token_ids)
                    self._tensor_based_metric(best_predictions,
                                              gold_tokens)  # type: ignore
                if self._token_based_metric is not None:
                    predicted_tokens = self._get_predicted_tokens(
                        output_dict["predictions"], metadata, n_best=1)
                    self._token_based_metric(  # type: ignore
                        predicted_tokens,
                        [x["target_tokens"] for x in metadata])

        return output_dict

    def _gather_extended_gold_tokens(
        self,
        target_tokens: torch.Tensor,
        source_token_ids: torch.Tensor,
        target_token_ids: torch.Tensor,
    ) -> torch.LongTensor:
        """
        Modify the gold target tokens relative to the extended vocabulary.

        For gold targets that are OOV but were copied from the source, the OOV index
        will be changed to the index of the first occurrence in the source sentence,
        offset by the size of the target vocabulary.

        Parameters
        ----------
        target_tokens : ``torch.Tensor``
            Shape: `(batch_size, target_sequence_length)`.
        source_token_ids : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`.
        target_token_ids : ``torch.Tensor``
            Shape: `(batch_size, target_sequence_length)`.

        Returns
        -------
        torch.Tensor
            Modified `target_tokens` with OOV indices replaced by offset index
            of first match in source sentence.
        """
        batch_size, target_sequence_length = target_tokens.size()
        trimmed_source_length = source_token_ids.size(1)
        # Only change indices for tokens that were OOV in target vocab but copied from source.
        # shape: (batch_size, target_sequence_length)
        oov = target_tokens == self._oov_index
        # shape: (batch_size, target_sequence_length, trimmed_source_length)
        expanded_source_token_ids = source_token_ids.unsqueeze(1).expand(
            batch_size, target_sequence_length, trimmed_source_length)
        # shape: (batch_size, target_sequence_length, trimmed_source_length)
        expanded_target_token_ids = target_token_ids.unsqueeze(-1).expand(
            batch_size, target_sequence_length, trimmed_source_length)
        # shape: (batch_size, target_sequence_length, trimmed_source_length)
        matches = expanded_source_token_ids == expanded_target_token_ids
        # shape: (batch_size, target_sequence_length)
        copied = matches.sum(-1) > 0
        # shape: (batch_size, target_sequence_length)
        mask = (oov & copied).long()
        # The running count equals 1 from the first match until the second one;
        # multiplying by `matches` keeps only the first match itself.
        # shape: (batch_size, target_sequence_length)
        first_match = ((matches.cumsum(-1) == 1) * matches).to(
            torch.uint8).argmax(-1)
        # shape: (batch_size, target_sequence_length)
        new_target_tokens = (
            target_tokens * (1 - mask) +
            (first_match.long() + self._target_vocab_size) * mask)
        return new_target_tokens

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Initialize the encoded state to be passed to the first decoding time step.

        Sets ``"decoder_hidden"`` and ``"decoder_context"`` on `state` (in place)
        and returns it.
        """
        batch_size, _ = state["source_mask"].size()

        # Initialize the decoder hidden state with the final output of the encoder,
        # and the decoder context with zeros.
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._encoder.is_bidirectional())
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self.decoder_output_dim)

        return state

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode source input sentences.

        Reads the ``'bert'`` and ``'bert-offsets'`` entries of `source_tokens` and
        returns a fresh state dict with ``"source_mask"`` and ``"encoder_outputs"``.
        """
        # Call the module itself rather than `.forward(...)` so that any hooks
        # registered on the embedder are run.
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(
            source_tokens['bert'], source_tokens['bert-offsets'])
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

    def _decoder_step(
        self,
        last_predictions: torch.Tensor,
        selective_weights: torch.Tensor,
        state: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Advance the decoder LSTM cell by one step.

        The cell input is a projection of the concatenated target embedding of
        `last_predictions`, the attentive read and the selective read. Updates
        ``"decoder_hidden"`` and ``"decoder_context"`` on `state` (in place) and
        returns it.
        """
        # shape: (group_size, max_input_sequence_length)
        encoder_outputs_mask = state["source_mask"].float()
        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)
        # shape: (group_size, max_input_sequence_length)
        attentive_weights = self._attention(state["decoder_hidden"],
                                            state["encoder_outputs"],
                                            encoder_outputs_mask)
        # shape: (group_size, encoder_output_dim)
        attentive_read = util.weighted_sum(state["encoder_outputs"],
                                           attentive_weights)
        # The selective weights only cover the trimmed source (first and last
        # positions dropped), so the encoder outputs are trimmed to match.
        # shape: (group_size, encoder_output_dim)
        selective_read = util.weighted_sum(state["encoder_outputs"][:, 1:-1],
                                           selective_weights)
        # shape: (group_size, target_embedding_dim + encoder_output_dim * 2)
        decoder_input = torch.cat(
            (embedded_input, attentive_read, selective_read), -1)
        # shape: (group_size, decoder_input_dim)
        projected_decoder_input = self._input_projection_layer(decoder_input)

        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
            projected_decoder_input,
            (state["decoder_hidden"], state["decoder_context"]))
        return state

    def _get_generation_scores(self,
                               state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Project the decoder hidden state to one un-normalized score per target vocab entry."""
        return self._output_generation_layer(state["decoder_hidden"])

    def _get_copy_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute one un-normalized copy score per trimmed source position: the dot
        product of the decoder hidden state with a tanh-activated linear projection
        of that position's encoder output.
        """
        # shape: (batch_size, max_input_sequence_length - 2, encoder_output_dim)
        trimmed_encoder_outputs = state["encoder_outputs"][:, 1:-1]
        # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim)
        copy_projection = self._output_copying_layer(trimmed_encoder_outputs)
        # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim)
        copy_projection = torch.tanh(copy_projection)
        # shape: (batch_size, max_input_sequence_length - 2)
        copy_scores = copy_projection.bmm(
            state["decoder_hidden"].unsqueeze(-1)).squeeze(-1)
        return copy_scores

    def _get_ll_contrib(
        self,
        generation_scores: torch.Tensor,
        generation_scores_mask: torch.Tensor,
        copy_scores: torch.Tensor,
        target_tokens: torch.Tensor,
        target_to_source: torch.Tensor,
        copy_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the log-likelihood contribution from a single timestep.

        Parameters
        ----------
        generation_scores : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`
        generation_scores_mask : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
        copy_scores : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        target_tokens : ``torch.Tensor``
            Shape: `(batch_size,)`
        target_to_source : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        copy_mask : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The step log-likelihood, shape `(batch_size,)`, and the selective
            weights for the next step, shape `(batch_size, trimmed_source_length)`.
        """
        _, target_size = generation_scores.size()

        # The point of this mask is to just mask out all source token scores
        # that just represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = log_probs[:, target_size:] + (
            target_to_source.float() + 1e-45).log()
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:],
                                                target_to_source)
        # This mask ensures that each item in the batch has a non-zero generation probability
        # for this timestep only when the gold target token is not OOV or there are no
        # matching tokens in the source sentence.
        # shape: (batch_size,)
        gen_mask = ((target_tokens != self._oov_index) |
                    (target_to_source.sum(-1) == 0)).float()
        # shape: (batch_size, 1)
        log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(
            1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + trimmed_source_length)
        combined_gen_and_copy = torch.cat(
            (generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)

        return step_log_likelihood, selective_weights

    def _forward_loss(
        self,
        target_tokens: Dict[str, torch.LongTensor],
        target_token_ids: torch.Tensor,
        state: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Calculate the loss against gold targets.

        Runs the decoder over the gold target sequence and returns
        ``{"loss": ...}``, the negative log-likelihood summed over unmasked
        timesteps and averaged over the batch.
        """
        batch_size, target_sequence_length = target_tokens["tokens"].size()

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # The last input from the target is either padding or the end symbol.
        # Either way, we don't have to process it.
        num_decoding_steps = target_sequence_length - 1
        # We use this to fill in the copy index when the previous input was copied.
        # shape: (batch_size,)
        copy_input_choices = source_mask.new_full((batch_size, ),
                                                  fill_value=self._copy_index)
        # shape: (batch_size, trimmed_source_length)
        copy_mask = source_mask[:, 1:-1].float()
        # We need to keep track of the probabilities assigned to tokens in the source
        # sentence that were copied during the previous timestep, since we use
        # those probabilities as weights when calculating the "selective read".
        # shape: (batch_size, trimmed_source_length)
        selective_weights = state["decoder_hidden"].new_zeros(copy_mask.size())

        # Indicates which tokens in the source sentence match the current target token.
        # shape: (batch_size, trimmed_source_length)
        target_to_source = state["source_token_ids"].new_zeros(
            copy_mask.size())

        # This is just a tensor of ones which we use repeatedly in `self._get_ll_contrib`,
        # so we create it once here to avoid doing it over-and-over.
        generation_scores_mask = state["decoder_hidden"].new_full(
            (batch_size, self._target_vocab_size), fill_value=1.0)

        step_log_likelihoods = []
        for timestep in range(num_decoding_steps):
            # shape: (batch_size,)
            input_choices = target_tokens["tokens"][:, timestep]
            # If the previous target token was copied, we use the special copy token.
            # But the end target token will always be THE end token, so we know
            # it was not copied.
            if timestep < num_decoding_steps - 1:
                # Get mask tensor indicating which instances were copied.
                # `target_to_source` still refers to the current input token here;
                # it is only advanced to the next target token below.
                # shape: (batch_size,)
                copied = ((input_choices == self._oov_index) &
                          (target_to_source.sum(-1) > 0)).long()
                # shape: (batch_size,)
                input_choices = input_choices * (
                    1 - copied) + copy_input_choices * copied
                # shape: (batch_size, trimmed_source_length)
                target_to_source = state[
                    "source_token_ids"] == target_token_ids[:, timestep +
                                                            1].unsqueeze(-1)
            # Update the decoder state by taking a step through the RNN.
            state = self._decoder_step(input_choices, selective_weights, state)
            # Get generation scores for each token in the target vocab.
            # shape: (batch_size, target_vocab_size)
            generation_scores = self._get_generation_scores(state)
            # Get copy scores for each token in the source sentence, excluding the start
            # and end tokens.
            # shape: (batch_size, trimmed_source_length)
            copy_scores = self._get_copy_scores(state)
            # shape: (batch_size,)
            step_target_tokens = target_tokens["tokens"][:, timestep + 1]
            step_log_likelihood, selective_weights = self._get_ll_contrib(
                generation_scores,
                generation_scores_mask,
                copy_scores,
                step_target_tokens,
                target_to_source,
                copy_mask,
            )
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        # Gather step log-likelihoods.
        # shape: (batch_size, num_decoding_steps = target_sequence_length - 1)
        log_likelihoods = torch.cat(step_log_likelihoods, 1)
        # Get target mask to exclude likelihood contributions from timesteps after
        # the END token.
        # shape: (batch_size, target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)
        # The first timestep is just the START token, which is not included in the likelihoods.
        # shape: (batch_size, num_decoding_steps)
        target_mask = target_mask[:, 1:].float()
        # Sum of step log-likelihoods.
        # shape: (batch_size,)
        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)
        # The loss is the negative log-likelihood, averaged over the batch.
        loss = -log_likelihood.sum() / batch_size

        return {"loss": loss}

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run the beam search from the start symbol and return the top sequences
        under ``"predictions"`` with their scores under ``"predicted_log_probs"``.
        """
        batch_size, source_length = state["source_mask"].size()
        trimmed_source_length = source_length - 2
        # Initialize the copy scores to zero: this stores log(0 + 1e-45), i.e. the
        # log-space form of (near-)zero copy probabilities.
        state["copy_log_probs"] = (state["decoder_hidden"].new_zeros(
            (batch_size, trimmed_source_length)) + 1e-45).log()
        # shape: (batch_size,)
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)
        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_search_step)
        return {
            "predicted_log_probs": log_probabilities,
            "predictions": all_top_k_predictions
        }

    def _get_input_and_selective_weights(
            self, last_predictions: torch.LongTensor,
            state: Dict[str, torch.Tensor]) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Get input choices for the decoder and the selective copy weights.

        The decoder input choices are simply the `last_predictions`, except for
        target OOV predictions that were copied from source tokens, in which case
        the prediction will be changed to the COPY symbol in the target namespace.

        The selective weights are just the probabilities assigned to source
        tokens that were copied, normalized to sum to 1. If no source tokens were copied,
        there will be all zeros.

        Parameters
        ----------
        last_predictions : ``torch.LongTensor``
            Shape: `(group_size,)`
        state : ``Dict[str, torch.Tensor]``

        Returns
        -------
        Tuple[torch.LongTensor, torch.Tensor]
            `input_choices` (shape `(group_size,)`) and `selective_weights`
            (shape `(group_size, trimmed_source_length)`).
        """
        group_size, trimmed_source_length = state["source_to_target"].size()

        # This is a mask indicating which last predictions were copied from the
        # the source AND not in the target vocabulary (OOV).
        # (group_size,)
        only_copied_mask = (last_predictions >= self._target_vocab_size).long()

        # If the last prediction was in the target vocab or OOV but not copied,
        # we use that as input, otherwise we use the COPY token.
# shape: (group_size,) copy_input_choices = only_copied_mask.new_full( (group_size, ), fill_value=self._copy_index) input_choices = (last_predictions * (1 - only_copied_mask) + copy_input_choices * only_copied_mask) # In order to get the `selective_weights`, we need to find out which predictions # were copied or copied AND generated, which is the case when a prediction appears # in both the source sentence and the target vocab. But whenever a prediction # is in the target vocab (even if it also appeared in the source sentence), # its index will be the corresponding target vocab index, not its index in # the source sentence offset by the target vocab size. So we first # use `state["source_to_target"]` to get an indicator of every source token # that matches the predicted target token. # shape: (group_size, trimmed_source_length) expanded_last_predictions = last_predictions.unsqueeze(-1).expand( group_size, trimmed_source_length) # shape: (group_size, trimmed_source_length) source_copied_and_generated = ( state["source_to_target"] == expanded_last_predictions).long() # In order to get indicators for copied source tokens that are OOV with respect # to the target vocab, we'll make use of `state["source_token_ids"]`. # First we adjust predictions relative to the start of the source tokens. # This makes sense because predictions for copied tokens are given by the index of the copied # token in the source sentence, offset by the size of the target vocabulary. # shape: (group_size,) adjusted_predictions = last_predictions - self._target_vocab_size # The adjusted indices for items that were not copied will be negative numbers, # and therefore invalid. So we zero them out. 
adjusted_predictions = adjusted_predictions * only_copied_mask # shape: (group_size, trimmed_source_length) source_token_ids = state["source_token_ids"] # shape: (group_size, trimmed_source_length) adjusted_prediction_ids = source_token_ids.gather( -1, adjusted_predictions.unsqueeze(-1)) # This mask will contain indicators for source tokens that were copied # during the last timestep. # shape: (group_size, trimmed_source_length) source_only_copied = ( source_token_ids == adjusted_prediction_ids).long() # Since we zero'd-out indices for predictions that were not copied, # we need to zero out all entries of this mask corresponding to those predictions. source_only_copied = source_only_copied * only_copied_mask.unsqueeze( -1) # shape: (group_size, trimmed_source_length) mask = source_only_copied | source_copied_and_generated # shape: (group_size, trimmed_source_length) selective_weights = util.masked_softmax(state["copy_log_probs"], mask) return input_choices, selective_weights def _gather_final_log_probs( self, generation_log_probs: torch.Tensor, copy_log_probs: torch.Tensor, state: Dict[str, torch.Tensor], ) -> torch.Tensor: """ Combine copy probabilities with generation probabilities for matching tokens. Parameters ---------- generation_log_probs : ``torch.Tensor`` Shape: `(group_size, target_vocab_size)` copy_log_probs : ``torch.Tensor`` Shape: `(group_size, trimmed_source_length)` state : ``Dict[str, torch.Tensor]`` Returns ------- torch.Tensor Shape: `(group_size, target_vocab_size + trimmed_source_length)`. 
""" _, trimmed_source_length = state["source_to_target"].size() source_token_ids = state["source_token_ids"] # shape: [(batch_size, *)] modified_log_probs_list: List[torch.Tensor] = [] for i in range(trimmed_source_length): # shape: (group_size,) copy_log_probs_slice = copy_log_probs[:, i] # `source_to_target` is a matrix of shape (group_size, trimmed_source_length) # where element (i, j) is the vocab index of the target token that matches the jth # source token in the ith group, if there is one, or the index of the OOV symbol otherwise. # We'll use this to add copy scores to corresponding generation scores. # shape: (group_size,) source_to_target_slice = state["source_to_target"][:, i] # The OOV index in the source_to_target_slice indicates that the source # token is not in the target vocab, so we don't want to add that copy score # to the OOV token. copy_log_probs_to_add_mask = (source_to_target_slice != self._oov_index).float() copy_log_probs_to_add = ( copy_log_probs_slice + (copy_log_probs_to_add_mask + 1e-45).log()) # shape: (batch_size, 1) copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1) # shape: (batch_size, 1) selected_generation_log_probs = generation_log_probs.gather( 1, source_to_target_slice.unsqueeze(-1)) combined_scores = util.logsumexp( torch.cat( (selected_generation_log_probs, copy_log_probs_to_add), dim=1)) generation_log_probs = generation_log_probs.scatter( -1, source_to_target_slice.unsqueeze(-1), combined_scores.unsqueeze(-1)) # We have to combine copy scores for duplicate source tokens so that # we can find the overall most likely source token. So, if this is the first # occurence of this particular source token, we add the log_probs from all other # occurences, otherwise we zero it out since it was already accounted for. if i < (trimmed_source_length - 1): # Sum copy scores from future occurences of source token. 
# shape: (group_size, trimmed_source_length - i) source_future_occurences = (source_token_ids[:, ( i + 1):] == source_token_ids[:, i].unsqueeze(-1)).float() # noqa # shape: (group_size, trimmed_source_length - i) future_copy_log_probs = ( copy_log_probs[:, (i + 1):] + (source_future_occurences + 1e-45).log()) # shape: (group_size, 1 + trimmed_source_length - i) combined = torch.cat((copy_log_probs_slice.unsqueeze(-1), future_copy_log_probs), dim=-1) # shape: (group_size,) copy_log_probs_slice = util.logsumexp(combined) if i > 0: # Remove copy log_probs that we have already accounted for. # shape: (group_size, i) source_previous_occurences = source_token_ids[:, 0: i] == source_token_ids[:, i].unsqueeze( -1) # shape: (group_size,) duplicate_mask = (source_previous_occurences.sum( dim=-1) == 0).float() copy_log_probs_slice = copy_log_probs_slice + (duplicate_mask + 1e-45).log() # Finally, we zero-out copy scores that we added to the generation scores # above so that we don't double-count them. # shape: (group_size,) left_over_copy_log_probs = ( copy_log_probs_slice + (1.0 - copy_log_probs_to_add_mask + 1e-45).log()) modified_log_probs_list.append( left_over_copy_log_probs.unsqueeze(-1)) modified_log_probs_list.insert(0, generation_log_probs) # shape: (group_size, target_vocab_size + trimmed_source_length) modified_log_probs = torch.cat(modified_log_probs_list, dim=-1) return modified_log_probs def take_search_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take step during beam search. This function is what gets passed to the `BeamSearch.search` method. It takes predictions from the last timestep and the current state and outputs the log probabilities assigned to tokens for the next timestep, as well as the updated state. 
Since we are predicting tokens out of the extended vocab (target vocab + all unique tokens from the source sentence), this is a little more complicated that just making a forward pass through the model. The output log probs will have shape `(group_size, target_vocab_size + trimmed_source_length)` so that each token in the target vocab and source sentence are assigned a probability. Note that copy scores are assigned to each source token based on their position, not unique value. So if a token appears more than once in the source sentence, it will have more than one score. Further, if a source token is also part of the target vocab, its final score will be the sum of the generation and copy scores. Therefore, in order to get the score for all tokens in the extended vocab at this step, we have to combine copy scores for re-occuring source tokens and potentially add them to the generation scores for the matching token in the target vocab, if there is one. So we can break down the final log probs output as the concatenation of two matrices, A: `(group_size, target_vocab_size)`, and B: `(group_size, trimmed_source_length)`. Matrix A contains the sum of the generation score and copy scores (possibly 0) for each target token. Matrix B contains left-over copy scores for source tokens that do NOT appear in the target vocab, with zeros everywhere else. But since a source token may appear more than once in the source sentence, we also have to sum the scores for each appearance of each unique source token. So matrix B actually only has non-zero values at the first occurence of each source token that is not in the target vocab. Parameters ---------- last_predictions : ``torch.Tensor`` Shape: `(group_size,)` state : ``Dict[str, torch.Tensor]`` Contains all state tensors necessary to produce generation and copy scores for next step. Notes ----- `group_size` != `batch_size`. In fact, `group_size` = `batch_size * beam_size`. 
""" _, trimmed_source_length = state["source_to_target"].size() # Get input to the decoder RNN and the selective weights. `input_choices` # is the result of replacing target OOV tokens in `last_predictions` with the # copy symbol. `selective_weights` consist of the normalized copy probabilities # assigned to the source tokens that were copied. If no tokens were copied, # there will be all zeros. # shape: (group_size,), (group_size, trimmed_source_length) input_choices, selective_weights = self._get_input_and_selective_weights( last_predictions, state) # Update the decoder state by taking a step through the RNN. state = self._decoder_step(input_choices, selective_weights, state) # Get the un-normalized generation scores for each token in the target vocab. # shape: (group_size, target_vocab_size) generation_scores = self._get_generation_scores(state) # Get the un-normalized copy scores for each token in the source sentence, # excluding the start and end tokens. # shape: (group_size, trimmed_source_length) copy_scores = self._get_copy_scores(state) # Concat un-normalized generation and copy scores. # shape: (batch_size, target_vocab_size + trimmed_source_length) all_scores = torch.cat((generation_scores, copy_scores), dim=-1) # shape: (group_size, trimmed_source_length) copy_mask = state["source_mask"][:, 1:-1].float() # shape: (batch_size, target_vocab_size + trimmed_source_length) mask = torch.cat((generation_scores.new_full(generation_scores.size(), 1.0), copy_mask), dim=-1) # Normalize generation and copy scores. # shape: (batch_size, target_vocab_size + trimmed_source_length) log_probs = util.masked_log_softmax(all_scores, mask) # shape: (group_size, target_vocab_size), (group_size, trimmed_source_length) generation_log_probs, copy_log_probs = log_probs.split( [self._target_vocab_size, trimmed_source_length], dim=-1) # Update copy_probs needed for getting the `selective_weights` at the next timestep. 
state["copy_log_probs"] = copy_log_probs # We now have normalized generation and copy scores, but to produce the final # score for each token in the extended vocab, we have to go through and add # the copy scores to the generation scores of matching target tokens, and sum # the copy scores of duplicate source tokens. # shape: (group_size, target_vocab_size + trimmed_source_length) final_log_probs = self._gather_final_log_probs(generation_log_probs, copy_log_probs, state) return final_log_probs, state def _get_predicted_tokens( self, predicted_indices: Union[torch.Tensor, np.ndarray], batch_metadata: List[Any], n_best: int = None, ) -> List[Union[List[List[str]], List[str]]]: """ Convert predicted indices into tokens. If `n_best = 1`, the result type will be `List[List[str]]`. Otherwise the result type will be `List[List[List[str]]]`. """ if not isinstance(predicted_indices, np.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() predicted_tokens: List[Union[List[List[str]], List[str]]] = [] for top_k_predictions, metadata in zip(predicted_indices, batch_metadata): batch_predicted_tokens: List[List[str]] = [] for indices in top_k_predictions[:n_best]: tokens: List[str] = [] indices = list(indices) if self._end_index in indices: indices = indices[:indices.index(self._end_index)] for index in indices: if index >= self._target_vocab_size: adjusted_index = index - self._target_vocab_size token = metadata["source_tokens"][adjusted_index] else: token = self.vocab.get_token_from_index( index, self._target_namespace) tokens.append(token) batch_predicted_tokens.append(tokens) if n_best == 1: predicted_tokens.append(batch_predicted_tokens[0]) else: predicted_tokens.append(batch_predicted_tokens) return predicted_tokens @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: """ Finalize predictions. After a beam search, the predicted indices correspond to tokens in the target vocabulary OR tokens in source sentence. 
Here we gather the actual tokens corresponding to the indices. """ predicted_tokens = self._get_predicted_tokens( output_dict["predictions"], output_dict["metadata"]) output_dict["predicted_tokens"] = predicted_tokens return output_dict @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if not self.training: if self._tensor_based_metric is not None: all_metrics.update( self._tensor_based_metric.get_metric( reset=reset) # type: ignore ) if self._token_based_metric is not None: all_metrics.update( self._token_based_metric.get_metric( reset=reset)) # type: ignore return all_metrics
class PointerGeneratorNetwork(Model):
    """
    Pointer-generator sequence-to-sequence model with optional coverage.

    Based on https://arxiv.org/pdf/1704.04368.pdf

    At every decoding step the model mixes a distribution over the target
    vocabulary with the attention distribution over source positions, using a
    learned gate `p_gen`. Source tokens that are unknown in the target namespace
    get batch-local ids appended after the regular vocabulary, so that they can
    be copied.

    Parameters
    ----------
    vocab : ``Vocabulary``
    source_embedder : ``TextFieldEmbedder``
        Embeds the source tokens.
    encoder : ``Seq2SeqEncoder``
        Encodes the embedded source; its output dimension is also used as the
        decoder hidden size.
    attention : ``Attention``
        Called with `(decoder_hidden, encoder_outputs, source_mask)` and, when
        coverage is enabled, additionally with the coverage vector.
    max_decoding_steps : ``int``
        Number of decoding steps when no targets are given, and the beam search
        step limit.
    beam_size : ``int``, optional
        Beam width; `None` (or 0) means a beam of size 1.
    target_namespace : ``str``
        Vocabulary namespace of the target tokens.
    target_embedding_dim : ``int``, optional
        Defaults to the output dimension of `source_embedder`.
    scheduled_sampling_ratio : ``float``
        Probability, during training, of feeding back the model's own previous
        prediction instead of the gold token.
    projection_dim : ``int``, optional
        Size of the layer between the decoder state and the vocabulary
        projection; defaults to the output dimension of `source_embedder`.
    use_coverage : ``bool``
        Whether to keep a coverage vector and add the coverage loss.
    coverage_shift : ``float``
        Shift used when rectifying the coverage loss.
    coverage_loss_weight : ``float``, optional
        Weight of the coverage loss; read only when `use_coverage` is true and
        targets are given.
    embed_attn_to_output : ``bool``
        If true, attention is computed after the LSTM step and concatenated
        with the hidden state before the output projection.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_decoding_steps: int,
                 beam_size: Optional[int] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: Optional[int] = None,
                 scheduled_sampling_ratio: float = 0.,
                 projection_dim: Optional[int] = None,
                 use_coverage: bool = False,
                 coverage_shift: float = 0.,
                 coverage_loss_weight: Optional[float] = None,
                 embed_attn_to_output: bool = False) -> None:
        super(PointerGeneratorNetwork, self).__init__(vocab)

        self._target_namespace = target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     target_namespace)
        self._unk_index = self.vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                     target_namespace)
        self._vocab_size = self.vocab.get_vocab_size(target_namespace)
        assert self._vocab_size > 2, \
            "Target vocabulary is empty. Make sure 'target_namespace' option of the model is correct."

        # Encoder
        self._source_embedder = source_embedder
        self._encoder = encoder
        self._encoder_output_dim = self._encoder.get_output_dim()

        # Decoder
        self._target_embedding_dim = (target_embedding_dim
                                      or source_embedder.get_output_dim())
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        self._target_embedder = Embedding(self._num_classes,
                                          self._target_embedding_dim)

        # The LSTM cell consumes [attention context; embedded previous token].
        self._decoder_input_dim = self._encoder_output_dim + self._target_embedding_dim
        self._decoder_output_dim = self._encoder_output_dim
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)

        self._projection_dim = (projection_dim
                                or self._source_embedder.get_output_dim())
        # With `embed_attn_to_output` the projection input is
        # [attention context; decoder hidden], hence twice the size.
        if embed_attn_to_output:
            hidden_projection_dim = self._decoder_output_dim * 2
        else:
            hidden_projection_dim = self._decoder_output_dim
        self._hidden_projection_layer = Linear(hidden_projection_dim,
                                               self._projection_dim)
        self._output_projection_layer = Linear(self._projection_dim,
                                               self._num_classes)

        # p_gen is computed from
        # [attention context; decoder hidden; decoder cell state; decoder input].
        self._p_gen_layer = Linear(
            self._decoder_output_dim * 3 + self._decoder_input_dim, 1)
        self._attention = attention
        self._use_coverage = use_coverage
        self._coverage_loss_weight = coverage_loss_weight
        # Added before taking logs to avoid log(0).
        self._eps = 1e-31
        self._embed_attn_to_output = embed_attn_to_output
        self._coverage_shift = coverage_shift

        # Metrics (running sums, averaged in `get_metrics`)
        self._p_gen_sum = 0.0
        self._p_gen_iterations = 0
        self._coverage_loss_sum = 0.0
        self._coverage_iterations = 0

        # Decoding
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

    def forward(self,
                source_tokens: Dict[str, torch.LongTensor],
                source_token_ids: torch.Tensor,
                source_to_target: torch.LongTensor,
                target_tokens: Dict[str, torch.LongTensor] = None,
                target_token_ids: torch.Tensor = None,
                metadata=None) -> Dict[str, torch.Tensor]:
        """
        Encode the source and compute the loss and/or predictions.

        When `target_tokens` is given, the decoding loop is run against the
        targets and its output (including `"loss"`) is returned. When the model
        is not in training mode, beam search predictions are added as well.
        `"metadata"` and `"source_to_target"` are always stored in the output
        because `decode` reads them.
        """
        state = self._encode(source_tokens)
        target_tokens_tensor = target_tokens["tokens"].long(
        ) if target_tokens else None
        extra_zeros, modified_source_tokens, modified_target_tokens = self._prepare(
            source_to_target, source_token_ids, target_tokens_tensor,
            target_token_ids)

        state["tokens"] = modified_source_tokens
        state["extra_zeros"] = extra_zeros

        output_dict = {}
        if target_tokens:
            state["target_tokens"] = modified_target_tokens
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, target_tokens)
        output_dict["metadata"] = metadata
        output_dict["source_to_target"] = source_to_target

        if not self.training:
            # Start beam search from a freshly initialised decoder state.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

        return output_dict

    def _prepare(self,
                 source_tokens: torch.LongTensor,
                 source_token_ids: torch.Tensor,
                 target_tokens: torch.LongTensor = None,
                 target_token_ids: torch.Tensor = None):
        """
        Map unknown tokens to batch-local ids in the extended vocabulary.

        Returns a tuple `(extra_zeros, modified_source_tokens,
        modified_target_tokens)`. `extra_zeros` has one column per extended
        vocabulary entry and is appended to the vocabulary distribution in
        `_get_final_dist`; `modified_target_tokens` is `None` when no targets
        are given.
        """
        batch_size = source_tokens.size(0)
        source_max_length = source_tokens.size(1)

        tokens = source_tokens
        token_ids = source_token_ids.long()
        # Concat target tokens if exist
        if target_tokens is not None:
            tokens = torch.cat((tokens, target_tokens), 1)
            token_ids = torch.cat((token_ids, target_token_ids.long()), 1)

        is_unk = torch.eq(tokens, self._unk_index).long()
        # Create tensor with ids of unknown tokens only.
        # Those ids are batch-local.
        unk_only = token_ids * is_unk

        # Recalculate batch-local ids to range [1, count_of_unique_unk_tokens].
        # All known tokens have zero id.
        unk_token_nums = token_ids.new_zeros((batch_size, token_ids.size(1)))
        for i in range(batch_size):
            # The inverse indices of the sorted unique values act as dense ranks.
            unique = torch.unique(unk_only[i, :],
                                  return_inverse=True,
                                  sorted=True)[1]
            unk_token_nums[i, :] = unique

        # Replace DEFAULT_OOV_TOKEN id with new batch-local ids starting from vocab_size
        # For example, if vocabulary size is 50000, the first unique unknown token will have 50000 index,
        # the second will have 50001 index and so on.
        tokens = tokens - tokens * is_unk + (self._vocab_size -
                                             1) * is_unk + unk_token_nums

        modified_target_tokens = None
        modified_source_tokens = tokens
        if target_tokens is not None:
            # Remove target unknown tokens that do not exist in source tokens
            max_source_num = torch.max(tokens[:, :source_max_length], dim=1)[0]
            vocab_size = max_source_num.new_full((1, ), self._vocab_size - 1)
            max_source_num = torch.max(max_source_num,
                                       other=vocab_size).unsqueeze(1).expand(
                                           (-1, tokens.size(1)))
            unk_target_tokens_mask = torch.gt(tokens, max_source_num).long()
            tokens = tokens - tokens * unk_target_tokens_mask + self._unk_index * unk_target_tokens_mask
            modified_target_tokens = tokens[:, source_max_length:]
            modified_source_tokens = tokens[:, :source_max_length]

        # Count unique unknown source tokens to create enough zeros for final distribution
        source_unk_count = int(
            torch.max(unk_token_nums[:, :source_max_length]).item())
        extra_zeros = tokens.new_zeros((batch_size, source_unk_count),
                                       dtype=torch.float32)
        return extra_zeros, modified_source_tokens, modified_target_tokens

    def _encode(
            self,
            source_tokens: Dict[str,
                                torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source; return the mask and encoder outputs."""
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder.forward(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder.forward(embedded_input, source_mask)

        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        (Re-)initialise the decoder entries of `state` in place and return it.

        The hidden state starts from the final encoder output; the cell state,
        and where enabled the attention context and coverage vector, start at
        zero.
        """
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output

        encoder_outputs = state["encoder_outputs"]
        state["decoder_context"] = encoder_outputs.new_zeros(
            batch_size, self._decoder_output_dim)
        if self._embed_attn_to_output:
            state["attn_context"] = encoder_outputs.new_zeros(
                encoder_outputs.size(0), encoder_outputs.size(2))
        if self._use_coverage:
            state["coverage"] = encoder_outputs.new_zeros(
                batch_size, encoder_outputs.size(1))
        return state

    def _prepare_output_projections(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take one decoder step.

        Returns the un-normalised scores over the target vocabulary and the
        state, updated in place with the new decoder input, hidden/cell state,
        attention scores, attention context and (if enabled) coverage.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]
        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]
        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]
        # shape: (group_size, decoder_output_dim)
        attn_context = state.get("attn_context", None)

        # Extended-vocabulary ids have no embedding; feed the OOV token instead.
        is_unk = (last_predictions >= self._vocab_size).long()
        last_predictions_fixed = last_predictions - last_predictions * is_unk + self._unk_index * is_unk
        embedded_input = self._target_embedder(last_predictions_fixed)

        coverage = state.get("coverage", None)

        def get_attention_context(decoder_hidden_inner):
            """Return attention scores over the source and their weighted sum."""
            if coverage is None:
                attention_scores = self._attention(decoder_hidden_inner,
                                                   encoder_outputs,
                                                   source_mask)
            else:
                attention_scores = self._attention(decoder_hidden_inner,
                                                   encoder_outputs,
                                                   source_mask, coverage)
            attention_context = util.weighted_sum(encoder_outputs,
                                                  attention_scores)
            return attention_scores, attention_context

        if not self._embed_attn_to_output:
            # Attend first, then feed the context into the LSTM cell.
            attn_scores, attn_context = get_attention_context(decoder_hidden)
            decoder_input = torch.cat((attn_context, embedded_input), -1)
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))
            projection = self._hidden_projection_layer(decoder_hidden)
        else:
            # Feed the previous step's context into the LSTM cell, then attend
            # with the new hidden state and project both together.
            decoder_input = torch.cat((attn_context, embedded_input), -1)
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))
            attn_scores, attn_context = get_attention_context(decoder_hidden)
            projection = self._hidden_projection_layer(
                torch.cat((attn_context, decoder_hidden), -1))
        output_projections = self._output_projection_layer(projection)

        if self._use_coverage:
            state["coverage"] = coverage + attn_scores
        state["decoder_input"] = decoder_input
        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        state["attn_scores"] = attn_scores
        state["attn_context"] = attn_context

        return output_projections, state

    def _get_final_dist(self, state: Dict[str, torch.Tensor],
                        output_projections):
        """
        Mix the vocabulary and attention distributions into one distribution.

        The result has one column per entry of the extended vocabulary
        (`num_classes` plus the width of `state["extra_zeros"]`) and is
        re-normalised to sum to one per row. Also accumulates the mean `p_gen`
        for `get_metrics`.
        """
        attn_dist = state["attn_scores"]
        tokens = state["tokens"]
        extra_zeros = state["extra_zeros"]
        attn_context = state["attn_context"]
        decoder_input = state["decoder_input"]
        decoder_hidden = state["decoder_hidden"]
        decoder_context = state["decoder_context"]

        decoder_state = torch.cat((decoder_hidden, decoder_context), 1)
        p_gen = self._p_gen_layer(
            torch.cat((attn_context, decoder_state, decoder_input), 1))
        p_gen = torch.sigmoid(p_gen)
        self._p_gen_sum += torch.mean(p_gen).item()
        self._p_gen_iterations += 1

        vocab_dist = F.softmax(output_projections, dim=-1)

        vocab_dist = vocab_dist * p_gen
        attn_dist = attn_dist * (1.0 - p_gen)
        if extra_zeros.size(1) != 0:
            vocab_dist = torch.cat((vocab_dist, extra_zeros), 1)
        # Add the attention mass of every source position to the column of the
        # (extended-vocabulary) token found at that position.
        final_dist = vocab_dist.scatter_add(1, tokens, attn_dist)
        normalization_factor = final_dist.sum(1, keepdim=True)
        final_dist = final_dist / normalization_factor

        return final_dist

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Decode greedily step by step, optionally against gold targets.

        Returns a dictionary with `"predictions"` (the arg-max token per step)
        and, when `target_tokens` is given, `"loss"`.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size(0)

        num_decoding_steps = self._max_decoding_steps
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]
            _, target_sequence_length = targets.size()
            # The last target token is never used as an input.
            num_decoding_steps = target_sequence_length - 1

        if self._use_coverage:
            coverage_loss = source_mask.new_zeros(1, dtype=torch.float32)

        last_predictions = state["tokens"].new_full(
            (batch_size, ), fill_value=self._start_index)
        step_proba: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: feed back the model's own prediction.
                input_choices = last_predictions
            elif not target_tokens:
                input_choices = last_predictions
            else:
                input_choices = targets[:, timestep]

            if self._use_coverage:
                # Coverage before this step's attention is added to it.
                old_coverage = state["coverage"]

            output_projections, state = self._prepare_output_projections(
                input_choices, state)
            final_dist = self._get_final_dist(state, output_projections)
            step_proba.append(final_dist)
            last_predictions = torch.max(final_dist, 1)[1]
            step_predictions.append(last_predictions.unsqueeze(1))

            if self._use_coverage:
                step_coverage_loss = torch.sum(
                    torch.min(state["attn_scores"], old_coverage), 1)
                coverage_loss = coverage_loss + step_coverage_loss

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # Class axis second, time axis last, which is the layout `NLLLoss`
            # expects for sequence targets.
            # shape: (batch_size, num_classes, num_decoding_steps)
            proba = torch.stack(step_proba, dim=2)

            loss = self._get_loss(proba, state["target_tokens"], self._eps)
            if self._use_coverage:
                coverage_loss = torch.mean(coverage_loss / num_decoding_steps)
                self._coverage_loss_sum += coverage_loss.item()
                self._coverage_iterations += 1
                modified_coverage_loss = relu(
                    coverage_loss -
                    self._coverage_shift) + self._coverage_shift - 1.0
                loss = loss + self._coverage_loss_weight * modified_coverage_loss
            output_dict["loss"] = loss

        return output_dict

    @staticmethod
    def _get_loss(proba: torch.Tensor, targets: torch.LongTensor,
                  eps: float) -> torch.Tensor:
        """
        Negative log-likelihood of the targets.

        `proba` holds probabilities with shape
        `(batch_size, num_classes, num_decoding_steps)`; the first target token
        is dropped and target index 0 is ignored.
        """
        targets = targets[:, 1:]
        proba = torch.log(proba + eps)
        loss = torch.nn.NLLLoss(ignore_index=0)(proba, targets)
        return loss

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run beam search from the START token using `self.take_step`."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["tokens"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def take_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Beam search step function.

        Returns the log probabilities over the extended vocabulary for the next
        token together with the updated state.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)
        final_dist = self._get_final_dist(state, output_projections)
        log_probabilities = torch.log(final_dist + self._eps)
        return log_probabilities, state

    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Convert predicted indices into tokens.

        Reads `"predictions"`, `"metadata"` and `"source_to_target"` from
        `output_dict`, stores the result under `"predicted_tokens"` and returns
        the same dictionary.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, np.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()

        all_meta = output_dict["metadata"]
        all_source_to_target = output_dict["source_to_target"]
        all_predicted_tokens = [
            self._decode_sample(indices, metadata, source_to_target)
            for indices, metadata, source_to_target in zip(
                predicted_indices, all_meta, all_source_to_target)
        ]
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _decode_sample(self, indices, metadata, source_to_target):
        """
        Convert the indices predicted for one instance into tokens.

        Indices below the vocabulary size are looked up in the target
        namespace. Larger indices refer to the unique unknown source tokens of
        this instance, in order of first appearance in
        `metadata["source_tokens"]`.
        """
        predicted_tokens = []
        # Beam search gives us the top k results for each source sentence in the batch
        # but we just want the single best.
        if len(indices.shape) > 1:
            indices = indices[0]
        indices = list(indices)
        # Collect indices till the first end_symbol
        if self._end_index in indices:
            indices = indices[:indices.index(self._end_index)]

        # Get all unknown tokens from source
        original_source_tokens = metadata["source_tokens"]
        unk_tokens = []
        for i, token_vocab_index in enumerate(source_to_target):
            if token_vocab_index != self._unk_index:
                continue
            token = original_source_tokens[i]
            if token not in unk_tokens:
                unk_tokens.append(token)

        for token_vocab_index in indices:
            if token_vocab_index < self._vocab_size:
                token = self.vocab.get_token_from_index(
                    token_vocab_index, namespace=self._target_namespace)
            else:
                unk_number = token_vocab_index - self._vocab_size
                # The extended vocabulary is sized by the largest number of
                # unknown source tokens in the batch, so an instance with fewer
                # of them can still receive an index it has no token for.
                # Fall back to the OOV symbol instead of raising IndexError.
                if unk_number < len(unk_tokens):
                    token = unk_tokens[unk_number]
                else:
                    token = DEFAULT_OOV_TOKEN
            predicted_tokens.append(token)
        return predicted_tokens

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the average coverage loss and `p_gen` since the last reset.

        The metrics are only reported when coverage is enabled (otherwise an
        empty dictionary is returned), but the running sums are cleared
        whenever `reset` is true so that they cannot grow without bound.
        """
        metrics: Dict[str, float] = {}
        if self._use_coverage:
            avg_coverage_loss = (self._coverage_loss_sum /
                                 self._coverage_iterations
                                 if self._coverage_iterations != 0 else 0.0)
            avg_p_gen = (self._p_gen_sum / self._p_gen_iterations
                         if self._p_gen_iterations != 0 else 0.0)
            metrics = {"coverage_loss": avg_coverage_loss, "p_gen": avg_p_gen}
        if reset:
            self._p_gen_sum = 0.0
            self._p_gen_iterations = 0
            self._coverage_loss_sum = 0.0
            self._coverage_iterations = 0
        return metrics
class SequenceTransformer(Model):
    """
    Encoder-decoder sequence-to-sequence model.

    Source tokens are embedded and encoded; a ``Decoder`` module consumes the encoder outputs
    together with the embedded target prefix, and a linear layer projects the decoder output onto
    the target vocabulary. When gold targets are given, a masked cross-entropy loss is computed
    from a single pass over the gold prefix. Outside of training mode, predictions are produced
    with beam search.

    # Parameters

    vocab : `Vocabulary`
        Vocabulary used to look up the start/end symbols and the target vocabulary size.
    source_embedder : `TextFieldEmbedder`
        Embedder applied to the source tokens.
    target_embedder : `Embedding`
        Embedder applied to target tokens. Its `num_embeddings` must equal the size of the
        target vocabulary.
    encoder : `Seq2SeqEncoder`
        Encoder applied to the embedded source. Its output dimension must equal the decoder's.
    max_decoding_steps : `int`
        Maximum number of steps the beam search is allowed to take.
    decoding_dim : `int`
        Decoder dimension; also the input size of the output projection layer.
    feedforward_hidden_dim, num_layers, num_attention_heads, use_positional_encoding,
    positional_encoding_max_steps, dropout_prob, residual_dropout_prob, attention_dropout_prob
        Forwarded unchanged to the `Decoder` constructor.
    beam_size : `int`, optional (default = 1)
        Width of the beam search.
    target_namespace : `str`, optional (default = "tokens")
        Vocabulary namespace of the target tokens.
    label_smoothing_ratio : `float`, optional (default = None)
        Label smoothing passed to the cross-entropy loss. `None` disables smoothing.
    initializer : `InitializerApplicator`, optional (default = None)
        If given, applied to the model at the end of construction.

    # Raises

    ConfigurationError
        If the target embedder's output dimension differs from the decoder's expected target
        embedding dimension, or the encoder and decoder output dimensions differ.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 target_embedder: Embedding,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 decoding_dim: int,
                 feedforward_hidden_dim: int,
                 num_layers: int,
                 num_attention_heads: int,
                 use_positional_encoding: bool = True,
                 positional_encoding_max_steps: int = 5000,
                 dropout_prob: float = 0.1,
                 residual_dropout_prob: float = 0.2,
                 attention_dropout_prob: float = 0.2,
                 beam_size: int = 1,
                 target_namespace: str = "tokens",
                 label_smoothing_ratio: Optional[float] = None,
                 initializer: Optional[InitializerApplicator] = None) -> None:
        super(SequenceTransformer, self).__init__(vocab)

        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        self._token_based_metric = TokenSequenceAccuracy()

        # Beam Search
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Encoder
        self._encoder = encoder

        # Vocabulary and embedder
        self._source_embedder = source_embedder
        self._target_embedder = target_embedder

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)
        assert target_vocab_size == self._target_embedder.num_embeddings

        target_embedding_dim = self._target_embedder.get_output_dim()

        self._decoding_dim = decoding_dim

        # Sequence Decoder Features
        self._output_projection_layer = Linear(self._decoding_dim, target_vocab_size)

        self._decoder = Decoder(
            num_layers=num_layers,
            decoding_dim=decoding_dim,
            target_embedding_dim=target_embedding_dim,
            feedforward_hidden_dim=feedforward_hidden_dim,
            num_attention_heads=num_attention_heads,
            use_positional_encoding=use_positional_encoding,
            positional_encoding_max_steps=positional_encoding_max_steps,
            dropout_prob=dropout_prob,
            residual_dropout_prob=residual_dropout_prob,
            attention_dropout_prob=attention_dropout_prob)

        # Parameter checks and cleanup
        if self._target_embedder.get_output_dim() != self._decoder.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )

        if self._encoder.get_output_dim() != self._decoder.get_output_dim():
            # The message must read the decoder's dimension: this class never defines
            # `_self_attention`, so formatting it used to raise AttributeError and hide
            # the real configuration problem.
            raise ConfigurationError(
                f"Encoder output dimension {self._encoder.get_output_dim()} should be"
                f" equal to decoder dimension {self._decoder.get_output_dim()}."
            )

        if initializer:
            initializer(self)

        # Print the model
        print(self)

    def take_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        # Parameters

        last_predictions : `torch.Tensor`
            A tensor of shape `(group_size,)`, which gives the indices of the predictions
            during the last time step.
        state : `Dict[str, torch.Tensor]`
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape `(group_size, *)`, where `*` can be any other number
            of dimensions.

        # Returns

        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of `(log_probabilities, updated_state)`, where `log_probabilities`
            is a tensor of shape `(group_size, num_classes)` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while `updated_state` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
        We treat the inputs as a batch, even though `group_size` is not necessarily
        equal to `batch_size`, since the group may contain multiple states
        for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._decoder_step(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            metadata: List[Dict[str, Any]],
            target_tokens: Optional[Dict[str, torch.LongTensor]] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
            The output of `TextField.as_array()` applied on the source `TextField`. This will be
            passed through a `TextFieldEmbedder` and then through an encoder.
        metadata: List[Dict[str, Any]]
            Additional information for prediction. When targets are given outside of training
            mode, each entry's ``"target_tokens"`` is fed to the token-based metric.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
            target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
            Contains ``"loss"`` when targets are given, and ``"predictions"`` /
            ``"class_log_probabilities"`` (plus ``"predicted_tokens"`` when targets are given)
            outside of training mode.
        """
        state = self._encode(source_tokens)

        if target_tokens:
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            if target_tokens:
                # shape: (batch_size, max_predicted_sequence_length)
                predicted_tokens = self.decode(output_dict)["predicted_tokens"]

                self._token_based_metric(
                    predicted_tokens, [x["target_tokens"] for x in metadata])

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()

        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)

            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]

            predicted_tokens = [
                self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)

        output_dict["predicted_tokens"] = all_predicted_tokens  # type: ignore
        return output_dict

    def _encode(
            self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Make forward pass on the encoder.

        # Parameters

        source_tokens : `Dict[str, torch.Tensor]`
            The output of `TextField.as_array()` applied on the source `TextField`. This will be
            passed through a `TextFieldEmbedder` and then through an encoder.

        # Returns

        Dict[str, torch.Tensor]
            Map consisting of the key `source_mask` with the mask over the
            `source_tokens` text field,
            and the key `encoder_outputs` with the output tensor from
            forward pass on the encoder.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)

        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)

        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Run the decoder once over the gold target prefix and compute the loss.

        The embedded targets without their last step are fed to the decoder, and the resulting
        logits are compared against the targets shifted by one step (see `_get_loss`).

        # Parameters

        state : `Dict[str, torch.Tensor]`
            Output of `_encode`; must contain `encoder_outputs` and `source_mask`.
        target_tokens : `Dict[str, torch.LongTensor]`
            Gold targets; the `"tokens"` entry is read. Despite the `None` default this
            argument is required, because it is indexed unconditionally.

        # Returns

        Dict[str, torch.Tensor]
            A dictionary with the single key `"loss"`.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = target_tokens["tokens"]

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self._target_embedder(targets)

        # shape: (batch_size, max_target_batch_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        _, decoder_output = self._decoder(
            previous_state=state,
            previous_steps_predictions=target_embedding[:, :-1, :],
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_mask=target_mask[:, :-1])

        # `.float()` casts to float32 on the tensor's current device.
        # `.type(torch.FloatTensor)` names the CPU tensor type, so it silently copied the
        # logits (and hence the loss computation) off the GPU.
        # shape: (group_size, max_target_sequence_length, num_classes)
        logits = self._output_projection_layer(decoder_output).float()

        # Compute loss.
        loss = self._get_loss(logits, targets, target_mask)
        output_dict = {"loss": loss}

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare inputs for the beam search, does beam search and returns beam search results.
        """
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _decoder_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`. The embeddings of all tokens predicted so far
        are accumulated in `state["previous_steps_predictions"]`, and `state` is updated in
        place with whatever state the decoder returns.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions")

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self._target_embedder(last_predictions).unsqueeze(1)

        if previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0:
            # There is no previous steps, except for start vectors in `last_predictions`
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1)

        decoder_state, decoder_output = self._decoder(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions,
        )
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)

        if self._decoder.decodes_parallel:
            # The decoder returned an output for every step; only the last one is needed.
            decoder_output = decoder_output[:, -1, :]

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _get_loss(self, logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account. If a label smoothing ratio was configured
        on the model, it is applied to the loss.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous().to(logits.device)

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous().to(logits.device)

        # `label_smoothing_ratio` was accepted by the constructor but never reached the loss;
        # `None` (the default) leaves the loss unsmoothed, exactly as before.
        return util.sequence_cross_entropy_with_logits(
            logits,
            relevant_targets,
            relevant_mask,
            label_smoothing=self._label_smoothing_ratio)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the token-based sequence metric; empty while the model is in training mode.
        """
        all_metrics: Dict[str, float] = {}
        if not self.training:
            all_metrics.update(
                self._token_based_metric.get_metric(reset=reset))
        return all_metrics
class Event2Mind(Model):
    """
    This ``Event2Mind`` class is a :class:`Model` which takes an event sequence, encodes it, and
    then uses the encoded representation to decode several mental state sequences.

    It is based on `the paper by Rashkin et al.
    <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (``tokens``) or the target tokens can have a different namespace, in which case it needs to
        be specified as ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    embedding_dropout: float, required
        The amount of dropout to apply after the source tokens have been embedded.
    encoder : ``Seq2VecEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
    beam_size : int, optional (default = 10)
        The width of the beam search.
    target_names: ``List[str]``, optional, (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : int, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: Optional[List[str]] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: Optional[int] = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        # One independent decoder (with its own recall metric) per target name.
        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(
                num_classes,
                target_embedding_dim,
                self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
            self._end_index,
            beam_size=beam_size,
            max_steps=max_decoding_steps
        )

    def _update_recall(self,
                       all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
                       target_recall: UnigramRecall) -> None:
        """
        Feed the beam-search predictions and the gold targets (without their first
        timestep, i.e. the start symbol) to ``target_recall``.
        """
        targets = target_tokens["tokens"]
        target_mask = get_text_field_mask(target_tokens)
        # See comment in _get_loss.
        # TODO(brendanr): Do we need contiguous here?
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()

        target_recall(
            all_top_k_predictions,
            relevant_targets,
            relevant_mask,
            self._end_index
        )

    def _get_num_decoding_steps(self,
                                target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int:
        """
        Number of decoder steps to run: one fewer than the target length when targets are
        given, otherwise ``max_decoding_steps``.
        """
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end
            # symbol. Either way, we don't have to process it. (To be clear,
            # we do still output and compare against the end symbol, but there
            # is no need to take the end symbol as input to the decoder.)
            return target_sequence_length - 1
        else:
            return self._max_decoding_steps

    @overrides
    def forward(self,  # type: ignore
                source: Dict[str, torch.LongTensor],
                **target_tokens: Dict[str, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
        """
        Decoder logic for producing the target sequences.

        Parameters
        ----------
        source : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the source
            ``TextField``. This will be passed through a ``TextFieldEmbedder``
            and then through an encoder.
        target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``:
            Dictionary from name to output of ``Textfield.as_array()`` applied
            on target ``TextField``. We assume that the target tokens are also
            represented as a ``TextField``.

        Returns
        -------
        Dict[str, torch.Tensor]
            With targets: ``"<name>_loss"`` per decoder and their mean under ``"loss"``.
            Outside of training mode: ``"<name>_top_k_predictions"`` and
            ``"<name>_top_k_log_probabilities"`` per decoder.

        Raises
        ------
        Exception
            If targets are given but their names differ from the configured decoder names.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Run the decoders over the gold targets so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception("Mismatch between target_tokens and self._states. Keys in " +
                                f"targets only: {target_only} Keys in states only: {states_only}")
            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                    final_encoder_output=final_encoder_output,
                    target_tokens=target_tokens[name],
                    target_embedder=state.embedder,
                    decoder_cell=state.decoder_cell,
                    output_projection_layer=state.output_projection_layer
                )
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                    (batch_size,), fill_value=self._start_index, dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, 10, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                    start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions, target_tokens[name], state.recall)
                output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self,
                      final_encoder_output: torch.Tensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding,
                      decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Runs the provided ``decoder_cell`` over the gold target sequence and returns the cross
        entropy between the resulting logits and ``target_tokens``.

        Despite the name, no predicted token is ever fed back: the input at every timestep is
        the gold target token of that timestep (teacher forcing).

        Parameters
        ----------
        final_encoder_output : ``torch.Tensor``, required
            Vector produced by ``self._encoder``; used as the initial decoder hidden state.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self,
                       final_encoder_output: torch.Tensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence (without the start symbol), always
        ``max_decoding_steps`` long.

        Parameters
        ----------
        final_encoder_output : ``torch.Tensor``, required
            Vector produced by ``self._encoder``; used as the initial decoder hidden state.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
            (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            # Feed the most recent prediction back in as the next input.
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.FloatTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        relevant_targets = targets[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        """
        Convert a 2-D collection of token indices into token strings, one list per row,
        cutting each row at its first end symbol.
        """
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, List[List[List[str]]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``. Each added
        field holds one entry per batch element, and each entry is the list of top-k token
        sequences for that element.
        """
        for name in self._states:
            # (batch_size, beam_size, num_decoding_steps)
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"]
            # Decode every batch element. Only element 0 used to be decoded, so the
            # predictions of every other instance in a batch were dropped. For a batch of
            # one the result is unchanged.
            output_dict[f"{name}_top_k_predicted_tokens"] = [
                self.decode_all(instance_indices)
                for instance_indices in top_k_predicted_indices
            ]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the recall metric of every decoder, keyed by decoder name; empty while the
        model is in training mode.
        """
        all_metrics = {}
        # Recall@10 needs beam search which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics