def __init__(self,
             pretrained_model: str,
             discriminative_loss_weight: float = 0,
             vocab: Vocabulary = Vocabulary(),
             softmax_over_vocab: bool = False,
             initializer: InitializerApplicator = InitializerApplicator()) -> None:
    super(GNLI, self).__init__(vocab)

    # Check the arguments of `__init__()`.
    assert pretrained_model in ['bart.large']
    assert 0 <= discriminative_loss_weight <= 1

    # Load in BART and extend the embeddings layer by three for the label embeddings.
    self._bart = torch.hub.load('pytorch/fairseq', pretrained_model).model
    self._extend_embeddings()

    # Ignore padding indices when calculating generative loss.
    assert self._bart.encoder.padding_idx == 1
    self._generative_loss_fn = torch.nn.CrossEntropyLoss(
        ignore_index=self._bart.encoder.padding_idx)
    self._discriminative_loss_fn = torch.nn.NLLLoss()
    self._discriminative_loss_weight = discriminative_loss_weight

    self._softmax_over_vocab = softmax_over_vocab
    if self._softmax_over_vocab:
        self.effective_vocab_size = self.vocab_size
    else:
        self.effective_vocab_size = self.vocab_size + self.label_size

    self.metrics = {
        'accuracy': CategoricalAccuracy(),
        'disc_loss': Average(),
        'gen_loss': Average()
    }

    initializer(self)
    number_params = sum([
        numpy.prod(p.size()) for p in list(self.parameters())
        if p.requires_grad
    ])
    logger.info('Number of trainable model parameters: %d', number_params)
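# Hedged sketch (the real `_extend_embeddings` body is not shown in this excerpt): one way to
# append three label rows to BART's shared token embedding so the labels can be predicted like
# ordinary tokens. The attribute paths `encoder.embed_tokens` / `decoder.embed_tokens` and the
# init scale are assumptions for illustration, not the project's confirmed implementation.
import torch


def extend_embeddings_sketch(bart_model, num_labels: int = 3):
    old = bart_model.encoder.embed_tokens                      # assumed attribute path
    vocab_size, dim = old.weight.shape
    new = torch.nn.Embedding(vocab_size + num_labels, dim, padding_idx=old.padding_idx)
    new.weight.data[:vocab_size] = old.weight.data             # keep the pretrained rows
    torch.nn.init.normal_(new.weight.data[vocab_size:], mean=0.0, std=0.02)
    bart_model.encoder.embed_tokens = new                      # BART shares embeddings between
    bart_model.decoder.embed_tokens = new                      # encoder and decoder
    return bart_model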
class Cpm(Model):
    """
    The ``Cpm`` applies a "contextualizing" ``Seq2SeqEncoder`` to uncontextualized
    embeddings, using ``torch.nn.functional.kl_div`` to compute the language
    modeling loss.

    If bidirectional is True, the language model is trained to predict the next and
    previous tokens for each token in the input. In this case, the contextualizer must
    be bidirectional. If bidirectional is False, the language model is trained to only
    predict the next token for each token in the input; the contextualizer should also
    be unidirectional.

    If your language model is bidirectional, it is IMPORTANT that your bidirectional
    ``Seq2SeqEncoder`` contextualizer does not do any "peeking ahead". That is, for its
    forward direction it should only consider embeddings at previous timesteps, and for
    its backward direction only embeddings at subsequent timesteps. Similarly, if your
    language model is unidirectional, the unidirectional contextualizer should only
    consider embeddings at previous timesteps. If this condition is not met, your
    language model is cheating.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the indexed tokens we get in ``forward``.
    contextualizer : ``Seq2SeqEncoder``
        Used to "contextualize" the embeddings. As described above, this encoder
        must not cheat by peeking ahead.
    dropout : ``float``, optional (default: None)
        If specified, dropout is applied to the contextualized embeddings before
        computation of the softmax. The contextualized embeddings themselves are
        returned without dropout.
    bidirectional : ``bool``, optional (default: False)
        Train a bidirectional language model, where the contextualizer is used to
        predict the next and previous token for each input token. This must match
        the bidirectionality of the contextualizer.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        hparams: Dict,
    ) -> None:
        super().__init__(vocab)
        self.text_field_embedder = text_field_embedder
        self.contextualizer = contextualizer
        self.bidirectional = contextualizer.is_bidirectional()
        if self.bidirectional:
            self.forward_dim = contextualizer.get_output_dim() // 2
        else:
            self.forward_dim = contextualizer.get_output_dim()

        dropout = hparams["dropout"]
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = lambda x: x

        self.hidden2chord = torch.nn.Sequential(
            torch.nn.Linear(self.forward_dim, hparams["fc_hidden_dim"]),
            torch.nn.ReLU(True),
            torch.nn.Linear(hparams["fc_hidden_dim"], vocab.get_vocab_size()),
        )
        self.perplexity = PerplexityCustom()
        self.accuracy = CategoricalAccuracy()
        self.real_loss = Average()

        self.similarity_matrix = hparams["similarity_matrix"]
        self.training_mode = hparams["training_mode"]
        self.T_initial = hparams["T_initial"]
        self.T = self.T_initial
        self.decay_rate = hparams["decay_rate"]
        self.batches_per_epoch = hparams["batches_per_epoch"]
        self.epoch = 0
        self.batch_counter = 0

    def num_layers(self) -> int:
        """
        Returns the depth of this LM. That is, how many layers the contextualizer
        has plus one for the non-contextual layer.
""" if hasattr(self.contextualizer, "num_layers"): return self.contextualizer.num_layers + 1 else: raise NotImplementedError( f"Contextualizer of type {type(self.contextualizer)} " + "does not report how many layers it has.") def loss_helper(self, direction_embeddings: torch.Tensor, direction_targets: torch.Tensor): mask = direction_targets > 0 # we need to subtract 1 to undo the padding id since the softmax # does not include a padding dimension # shape (batch_size * timesteps, ) non_masked_targets = direction_targets.masked_select(mask) # shape (batch_size * timesteps, embedding_dim) non_masked_embeddings = direction_embeddings.masked_select( mask.unsqueeze(-1)).view(-1, self.forward_dim) # note: need to return average loss across forward and backward # directions, but total sum loss across all batches. # Assuming batches include full sentences, forward and backward # directions have the same number of samples, so sum up loss # here then divide by 2 just below probs = torch.nn.functional.log_softmax( self.hidden2chord(non_masked_embeddings), dim=-1) real_loss = torch.nn.functional.nll_loss(probs, non_masked_targets, reduction="sum") # transform targets into probability distributions using Embedding # then compute loss using torch.nn.functional.kl_div if self.training: if self.training_mode == TM_ONE_HOT: train_loss = real_loss elif self.training_mode == TM_NO: target_distributions = self.similarity_matrix( non_masked_targets) train_loss = torch.nn.functional.kl_div(probs, target_distributions, reduction="sum") elif self.training_mode == TM_FIXED or self.training_mode == TM_DECREASED: target_distributions = self.similarity_matrix( non_masked_targets) target_distributions = torch.nn.functional.softmax( target_distributions / self.T, dim=1) train_loss = torch.nn.functional.kl_div(probs, target_distributions, reduction="sum") else: raise ValueError("Unknown training mode: {}".format( self.training_mode)) else: train_loss = real_loss return train_loss, real_loss @overrides def forward( self, input_tokens: Dict[str, torch.LongTensor], forward_output_tokens: Dict[str, torch.LongTensor], backward_output_tokens: Dict[str, torch.LongTensor] = None, ) -> Dict[str, torch.Tensor]: """ Computes the averaged forward (and backward, if language model is bidirectional) LM loss from the batch. Returns ------- Dict with keys: ``'loss'``: ``torch.Tensor`` forward negative log likelihood, or the average of forward/backward if language model is bidirectional ``'forward_loss'``: ``torch.Tensor`` forward direction negative log likelihood ``'backward_loss'``: ``torch.Tensor`` or ``None`` backward direction negative log likelihood. If language model is not bidirectional, this is ``None``. ``'contextual_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]`` (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or list of all layers. No dropout applied. 
        ``'noncontextual_token_embeddings'``: ``torch.Tensor``
            (batch_size, timesteps, token_embed_dim) tensor of bottom layer
            noncontextual representations
        ``'mask'``: ``torch.Tensor``
            (batch_size, timesteps) mask for the embeddings
        """
        self.batch_counter += 1
        if self.batch_counter % self.batches_per_epoch == 0:
            self.epoch += 1

        if self.training_mode == TM_DECREASED:
            self.T *= 1 / (1 + self.decay_rate * self.epoch)
            if self.T < 1e-20:
                self.T = 1e-20

        mask = get_text_field_mask(input_tokens)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self.text_field_embedder(input_tokens)

        contextual_embeddings = self.contextualizer(embeddings, mask)
        contextual_embeddings_with_dropout = self.dropout(
            contextual_embeddings)

        if self.bidirectional:
            forward_embeddings, backward_embeddings = contextual_embeddings_with_dropout.chunk(
                2, -1)
            backward_logits = self.hidden2chord(backward_embeddings)
        else:
            forward_embeddings = contextual_embeddings_with_dropout
            backward_logits = None
        forward_logits = self.hidden2chord(forward_embeddings)

        forward_targets = forward_output_tokens.get("tokens")
        if self.bidirectional:
            backward_targets = backward_output_tokens.get("tokens")

        # compute loss
        forward_loss, forward_real_loss = self.loss_helper(
            forward_embeddings, forward_targets)
        if self.bidirectional:
            backward_loss, backward_real_loss = self.loss_helper(
                backward_embeddings, backward_targets)
        else:
            backward_loss, backward_real_loss = None, None

        return_dict = {}
        num_targets = torch.sum((forward_targets > 0).long())
        if num_targets > 0:
            if self.bidirectional:
                average_loss = (0.5 * (forward_loss + backward_loss) /
                                num_targets.float())
                average_real_loss = (0.5 *
                                     (forward_real_loss + backward_real_loss) /
                                     num_targets.float())
            else:
                average_loss = forward_loss / num_targets.float()
                average_real_loss = forward_real_loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0).to(forward_targets.device)
            average_real_loss = torch.tensor(0.0).to(forward_targets.device)

        self.perplexity(average_real_loss)
        self.accuracy(forward_logits, forward_targets, mask)
        self.real_loss(average_real_loss)

        return_dict.update({"loss": average_loss})
        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            "contextual_embeddings": contextual_embeddings,
            "noncontextual_token_embeddings": embeddings,
            "forward_logits": forward_logits,
            "backward_logits": backward_logits,
            "mask": mask,
        })

        return return_dict

    def get_metrics(self, reset: bool = False):
        return {
            "perplexity": self.perplexity.get_metric(reset=reset),
            "accuracy": self.accuracy.get_metric(reset=reset),
            "real_loss": float(self.real_loss.get_metric(reset=reset)),
        }
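# Tiny, self-contained illustration (toy numbers, not project data) of the TM_FIXED / TM_DECREASED
# target construction used in `loss_helper` above: a similarity row becomes a soft target via
# softmax(similarity / T), and as T shrinks the target collapses onto the most-similar class,
# so the summed KL divergence against the model's log-probabilities approaches the plain NLL.
import torch

similarity_row = torch.tensor([[4.0, 2.0, 0.5, 0.0]])        # target chord vs. each vocab class
log_probs = torch.log_softmax(torch.randn(1, 4), dim=-1)      # model output for one position
for T in (2.0, 1.0, 0.1):                                     # smaller T -> sharper target
    soft_target = torch.softmax(similarity_row / T, dim=1)
    loss = torch.nn.functional.kl_div(log_probs, soft_target, reduction="sum")
    print(f"T={T}: target={soft_target.numpy().round(3)}, kl={loss.item():.3f}")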
class UnsupervisedTranslation(Model):
    """
    This ``UnsupervisedTranslation`` class is a :class:`Model` which takes a sequence,
    encodes it, and then uses the encoded representations to decode another sequence.
    It is an encoder-decoder model for unsupervised neural machine translation, trained
    with a combination of denoising auto-encoding, back-translation, and (optionally)
    parallel translation losses.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the
        same namespace (`tokens`) or the target tokens can have a different namespace,
        in which case it needs to be specified as `lang2_namespace`.
    dataset_reader : ``DatasetReader``, required
        The reader used to build the instances. The model takes the language list and
        the auto-encoding / back-translation / parallel step configuration from it, and
        reuses it to convert back-translated strings into new instances.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    lang2_namespace : ``str``, optional (default = 'tokens')
        Namespace of the target-side vocabulary.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
""" def __init__(self, vocab: Vocabulary, dataset_reader: DatasetReader, source_embedder: TextFieldEmbedder, lang2_namespace: str = "tokens", use_bleu: bool = True) -> None: super().__init__(vocab) self._lang1_namespace = lang2_namespace # TODO: DO NOT HARDCODE IT self._lang2_namespace = lang2_namespace # TODO: do not hardcore this self._backtranslation_src_langs = ["en", "ru"] self._coeff_denoising = 1 self._coeff_backtranslation = 1 self._coeff_translation = 1 self._label_smoothing = 0.1 self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang1_namespace) self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang1_namespace) self._end_index_lang1 = self.vocab.get_token_index( END_SYMBOL, self._lang1_namespace) self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang2_namespace) self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang2_namespace) self._end_index_lang2 = self.vocab.get_token_index( END_SYMBOL, self._lang2_namespace) self._reader = dataset_reader self._langs_list = self._reader._langs_list self._ae_steps = self._reader._ae_steps self._bt_steps = self._reader._bt_steps self._para_steps = self._reader._para_steps if use_bleu: self._bleu = Average() else: self._bleu = None args = ArgsStub() transformer_iwslt_de_en(args) # build encoder if not hasattr(args, 'max_source_positions'): args.max_source_positions = 1024 if not hasattr(args, 'max_target_positions'): args.max_target_positions = 1024 # Dense embedding of source vocab tokens. self._source_embedder = source_embedder # Dense embedding of vocab words in the target space. num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace) num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace) args.share_decoder_input_output_embed = False # TODO implement shared embeddings lang1_dict = DictStub(num_tokens=num_tokens_lang1, pad=self._pad_index_lang1, unk=self._oov_index_lang1, eos=self._end_index_lang1) lang2_dict = DictStub(num_tokens=num_tokens_lang2, pad=self._pad_index_lang2, unk=self._oov_index_lang2, eos=self._end_index_lang2) # instantiate fairseq classes emb_golden_tokens = FairseqEmbedding(num_tokens_lang2, args.decoder_embed_dim, self._pad_index_lang2) self._encoder = TransformerEncoder(args, lang1_dict, self._source_embedder) self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens) self._model = TransformerModel(self._encoder, self._decoder) # TODO: do not hardcode max_len_b and beam size self._sequence_generator_greedy = FairseqBeamSearchWrapper( SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20)) self._sequence_generator_beam = FairseqBeamSearchWrapper( SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20)) @overrides def forward( self, # type: ignore lang_pair: List[str], lang1_tokens: Dict[str, torch.LongTensor] = None, lang1_golden: Dict[str, torch.LongTensor] = None, lang2_tokens: Dict[str, torch.LongTensor] = None, lang2_golden: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ """ # detect training mode and what kind of task we need to compute if lang2_tokens is None and lang1_tokens is None: raise ConfigurationError( "source_tokens and target_tokens can not both be None") mode_training = self.training mode_validation = not self.training and lang2_tokens is not None # change 'target_tokens' condition mode_prediction = lang2_tokens is None # change 'target_tokens' condition lang_src, lang_tgt = lang_pair[0].split('-') if 
mode_training: # task types task_translation = False task_denoising = False task_backtranslation = False if lang_src == 'xx': task_backtranslation = True elif lang_src == lang_tgt: task_denoising = True elif lang_src != lang_tgt: task_translation = True else: raise ConfigurationError("All tasks are false") output_dict = {} if mode_training: if task_translation: loss = self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden) if self._bleu: predicted_indices = self._sequence_generator_beam.generate( [self._model], lang1_tokens, self._get_true_pad_mask(lang1_tokens), self._end_index_lang2) predicted_strings = self._indices_to_strings( predicted_indices) golden_strings = self._indices_to_strings( lang2_tokens["tokens"]) golden_strings = self._remove_pad_eos(golden_strings) # print(golden_strings, predicted_strings) self._bleu(corpus_bleu(golden_strings, predicted_strings)) elif task_denoising: # might need to split it into two blocks for interlingua loss loss = self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden) elif task_backtranslation: # our goal is also to learn from regular cross-entropy loss, but since we do not have source tokens, # we will generate them ourselves with current model langs_src = self._backtranslation_src_langs.copy() langs_src.remove(lang_tgt) bt_losses = {} for lang_src in langs_src: curr_lang_pair = lang_src + "-" + lang_tgt # TODO: require to pass target language to forward on encoder outputs # We use greedy decoder because it was shown better for backtranslation with torch.no_grad(): predicted_indices = self._sequence_generator_greedy.generate( [self._model], lang2_tokens, self._get_true_pad_mask(lang2_tokens), self._end_index_lang2) model_input = self._strings_to_batch( self._indices_to_strings(predicted_indices), lang2_tokens, lang2_golden, curr_lang_pair) bt_losses['bt:' + curr_lang_pair] = self._forward_seq2seq( **model_input) else: raise ConfigurationError("No task have been detected") if task_translation: loss = self._coeff_translation * loss elif task_denoising: loss = self._coeff_denoising * loss elif task_backtranslation: loss = 0 for bt_loss in bt_losses.values(): loss += self._coeff_backtranslation * bt_loss output_dict["loss"] = loss elif mode_validation: output_dict["loss"] = self._coeff_translation * \ self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden) if self._bleu: predicted_indices = self._sequence_generator_greedy.generate( [self._model], lang1_tokens, self._get_true_pad_mask(lang1_tokens), self._end_index_lang2) predicted_strings = self._indices_to_strings(predicted_indices) golden_strings = self._indices_to_strings( lang2_tokens["tokens"]) golden_strings = self._remove_pad_eos(golden_strings) print(golden_strings, predicted_strings) self._bleu(corpus_bleu(golden_strings, predicted_strings)) elif mode_prediction: # TODO: pass target language (in the fseq_encoder append embedded target language to the encoder out) predicted_indices = self._sequence_generator_beam.generate( [self._model], lang1_tokens, self._get_true_pad_mask(lang1_tokens), self._end_index_lang2) output_dict["predicted_indices"] = predicted_indices output_dict["predicted_strings"] = self._indices_to_strings( predicted_indices) return output_dict def _get_true_pad_mask(self, indexed_input): mask = util.get_text_field_mask(indexed_input) # TODO: account for cases when text field mask doesn't work, like BERT return mask def _remove_pad_eos(self, golden_strings): tmp = [] for x in golden_strings: tmp.append( list( filter( 
lambda a: a != DEFAULT_PADDING_TOKEN and a != END_SYMBOL, x))) return tmp def _convert_to_sentences(self, golden_strings, predicted_strings): golden_strings_nopad = [] for s in golden_strings: s_nopad = list(filter(lambda t: t != DEFAULT_PADDING_TOKEN, s)) s_nopad = " ".join(s_nopad) golden_strings_nopad.append(s_nopad) predicted_strings = [" ".join(s) for s in predicted_strings] return golden_strings_nopad, predicted_strings def _forward_seq2seq( self, lang_pair: List[str], source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor], target_golden: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: source_tokens_padding_mask = self._get_true_pad_mask(source_tokens) encoder_out = self._encoder.forward(source_tokens, source_tokens_padding_mask) logits, _ = self._decoder.forward(target_tokens["tokens"], encoder_out) loss = self._get_ce_loss(logits, target_golden) return loss def _get_ce_loss(self, logits, golden): target_mask = util.get_text_field_mask(golden) loss = util.sequence_cross_entropy_with_logits( logits, golden["golden_tokens"], target_mask, label_smoothing=self._label_smoothing) return loss def _indices_to_strings(self, indices: torch.Tensor): all_predicted_tokens = [] for hyp in indices: predicted_tokens = [ self.vocab.get_token_from_index( idx.item(), namespace=self._lang2_namespace) for idx in hyp ] all_predicted_tokens.append(predicted_tokens) return all_predicted_tokens def _strings_to_batch(self, source_tokens: List[List[str]], target_tokens: Dict[str, torch.Tensor], target_golden: Dict[str, torch.Tensor], lang_pair: str): """ Converts list of sentences which are itself lists of strings into Batch suitable for passing into model's forward function. TODO: Make sure the right device (CPU/GPU) is used. Predicted tokens might get copied on CPU in `self.decode` method... """ # convert source tokens into source tensor_dict instances = [] lang_pairs = [] for sentence in source_tokens: sentence = " ".join(sentence) instances.append(self._reader.string_to_instance(sentence)) lang_pairs.append(lang_pair) source_batch = Batch(instances) source_batch.index_instances(self.vocab) source_batch = source_batch.as_tensor_dict() model_input = { "source_tokens": source_batch["tokens"], "target_golden": target_golden, "target_tokens": target_tokens, "lang_pair": lang_pairs } return model_input @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update({"BLEU": self._bleu.get_metric(reset=reset)}) return all_metrics
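# Schematic of the back-translation branch in `forward` above, reduced to its core idea
# (framework-agnostic sketch; `generate_fn` and `seq2seq_loss_fn` are placeholder names, not
# APIs from this repository): synthesize a source sentence with gradients disabled, then apply
# the ordinary cross-entropy objective to reconstruct the original sentences from it.
import torch


def backtranslation_loss(model, generate_fn, seq2seq_loss_fn, target_batch):
    with torch.no_grad():                                     # the synthetic source is treated as data
        synthetic_source = generate_fn(model, target_batch)   # greedy decoding, as in the model above
    return seq2seq_loss_fn(model, synthetic_source, target_batch)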
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             cnn_size: int = 100,
             dropout_weight: float = 0.1,
             with_entity_embeddings: bool = True,
             sent_loss_weight: float = 1,
             attention_weight_fn: str = 'sigmoid',
             attention_aggregation_fn: str = 'max') -> None:
    regularizer = None
    super().__init__(vocab, regularizer)
    self.num_classes = self.vocab.get_vocab_size("labels")

    self.text_field_embedder = text_field_embedder
    self.dropout_weight = dropout_weight
    self.with_entity_embeddings = with_entity_embeddings
    self.sent_loss_weight = sent_loss_weight
    self.attention_weight_fn = attention_weight_fn
    self.attention_aggregation_fn = attention_aggregation_fn

    # instantiate position embedder
    pos_embed_output_size = 5
    pos_embed_input_size = 2 * RelationInstancesReader.max_distance + 1
    self.pos_embed = nn.Embedding(pos_embed_input_size, pos_embed_output_size)
    pos_embed_weights = np.array([range(pos_embed_input_size)] * pos_embed_output_size).T
    self.pos_embed.weight = nn.Parameter(torch.Tensor(pos_embed_weights))

    d = cnn_size
    sent_encoder = CnnEncoder  # TODO: should be moved to the config file
    cnn_output_size = d
    embedding_size = 300  # TODO: should be moved to the config file

    # instantiate sentence encoder
    self.cnn = sent_encoder(embedding_dim=(embedding_size + 2 * pos_embed_output_size),
                            num_filters=cnn_size,
                            ngram_filter_sizes=(2, 3, 4, 5),
                            conv_layer_activation=torch.nn.ReLU(),
                            output_dim=cnn_output_size)

    # dropout after word embedding
    self.dropout = nn.Dropout(p=self.dropout_weight)

    # given a sentence, returns its unnormalized attention weight
    self.attention_ff = nn.Sequential(nn.Linear(cnn_output_size, d),
                                      nn.ReLU(),
                                      nn.Linear(d, 1))

    self.ff_before_alpha = nn.Sequential(
        nn.Linear(1, 50),
        nn.ReLU(),
        nn.Linear(50, 1),
    )

    ff_input_size = cnn_output_size
    if self.with_entity_embeddings:
        ff_input_size += embedding_size

    # output layer
    self.ff = nn.Sequential(nn.Linear(ff_input_size, d),
                            nn.ReLU(),
                            nn.Linear(d, self.num_classes))

    self.loss = torch.nn.BCEWithLogitsLoss()  # sigmoid + binary cross entropy
    self.metrics = {}
    self.metrics['ap'] = MultilabelAveragePrecision()  # average precision = AUC
    self.metrics['bag_loss'] = Average()  # to display bag-level loss
    if self.sent_loss_weight > 0:
        self.metrics['sent_loss'] = Average()  # to display sentence-level loss
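# Hedged sketch of the indexing convention the position embedder above appears to expect
# (2 * max_distance + 1 rows): clip the signed token-to-entity distance to
# [-max_distance, max_distance] and shift it to a non-negative index. The helper name and the
# exact convention are illustrative assumptions, not code from this repository.
import torch


def distance_to_position_index(distances: torch.Tensor, max_distance: int) -> torch.Tensor:
    return torch.clamp(distances, -max_distance, max_distance) + max_distance


print(distance_to_position_index(torch.tensor([-100, -3, 0, 7, 100]), max_distance=30))
# tensor([ 0, 27, 30, 37, 60])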
class BertQA(Model): """ This class implements Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). The basic layout is pretty simple: encode words as a combination of word embeddings and a character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of attentions to put question information into the passage word representations (this is the only part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, sim_text_field_embedder: TextFieldEmbedder, loss_weights: Dict, sim_class_weights: List, pretrained_sim_path: str = None, use_scenario_encoding: bool = True, sim_pretraining: bool = False, dropout: float = 0.2, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(BertQA, self).__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder if use_scenario_encoding: self._sim_text_field_embedder = sim_text_field_embedder self.loss_weights = loss_weights self.sim_class_weights = sim_class_weights self.use_scenario_encoding = use_scenario_encoding self.sim_pretraining = sim_pretraining if self.sim_pretraining and not self.use_scenario_encoding: raise ValueError( "When pretraining Scenario Interpretation Module, you should use it." 
) embedding_dim = self._text_field_embedder.get_output_dim() self._action_predictor = torch.nn.Linear(embedding_dim, 4) self._sim_token_label_predictor = torch.nn.Linear(embedding_dim, 4) self._span_predictor = torch.nn.Linear(embedding_dim, 2) self._action_accuracy = CategoricalAccuracy() self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() self._span_loss_metric = Average() self._action_loss_metric = Average() self._sim_loss_metric = Average() self._sim_yes_f1 = F1Measure(2) self._sim_no_f1 = F1Measure(3) if use_scenario_encoding and pretrained_sim_path is not None: logger.info("Loading pretrained model..") self.load_state_dict(torch.load(pretrained_sim_path)) for param in self._sim_text_field_embedder.parameters(): param.requires_grad = False if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x initializer(self) def get_passage_representation(self, bert_output, bert_input): # Shape: (batch_size, bert_input_len) input_type_ids = self.get_input_type_ids( bert_input['bert-type-ids'], bert_input['bert-offsets'], self._text_field_embedder._token_embedders['bert']).float() # Shape: (batch_size, bert_input_len) input_mask = util.get_text_field_mask(bert_input).float() passage_mask = input_mask - input_type_ids # works only with one [SEP] # Shape: (batch_size, bert_input_len, embedding_dim) passage_representation = bert_output * passage_mask.unsqueeze(2) # Shape: (batch_size, passage_len, embedding_dim) passage_representation = passage_representation[:, passage_mask.sum( dim=0) > 0, :] # Shape: (batch_size, passage_len) passage_mask = passage_mask[:, passage_mask.sum(dim=0) > 0] return passage_representation, passage_mask def forward( self, # type: ignore bert_input: Dict[str, torch.LongTensor], sim_bert_input: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None, label: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question tokens, passage tokens, original passage text, and token offsets into the passage for each instance in the batch. The length of this list should be the batch size, and each dictionary should have the keys ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``. 
Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span end position (inclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)`` and each offset is a token index. loss : torch.FloatTensor, optional A scalar loss to be optimised. best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. """ if self.use_scenario_encoding: # Shape: (batch_size, sim_bert_input_len_wp) sim_bert_input_token_labels_wp = sim_bert_input[ 'scenario_gold_encoding'] # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim) sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input) # Shape: (batch_size, sim_bert_input_len_wp) sim_input_mask_wp = (sim_bert_input['bert'] != 0).float() # Shape: (batch_size, sim_bert_input_len_wp) sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[ 'bert-type-ids'].float() # works only with one [SEP] # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim) sim_passage_representation_wp = sim_bert_output_wp * sim_passage_mask_wp.unsqueeze( 2) # Shape: (batch_size, passage_len_wp, embedding_dim) sim_passage_representation_wp = sim_passage_representation_wp[:, sim_passage_mask_wp .sum( dim =0 ) > 0, :] # Shape: (batch_size, passage_len_wp) sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[:, sim_passage_mask_wp .sum( dim =0 ) > 0] # Shape: (batch_size, passage_len_wp) sim_passage_mask_wp = sim_passage_mask_wp[:, sim_passage_mask_wp.sum( dim=0) > 0] # Shape: (batch_size, passage_len_wp, 4) sim_token_logits_wp = self._sim_token_label_predictor( sim_passage_representation_wp) if span_start is not None: # during training and validation class_weights = torch.tensor(self.sim_class_weights, device=sim_token_logits_wp.device, dtype=torch.float) sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4), sim_passage_token_labels_wp.view(-1), ignore_index=0, weight=class_weights) self._sim_loss_metric(sim_loss.item()) self._sim_yes_f1(sim_token_logits_wp, sim_passage_token_labels_wp, sim_passage_mask_wp) self._sim_no_f1(sim_token_logits_wp, sim_passage_token_labels_wp, sim_passage_mask_wp) if self.sim_pretraining: return {'loss': sim_loss} if not self.sim_pretraining: # Shape: (batch_size, passage_len_wp) bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax( dim=2)) * sim_passage_mask_wp.long() # Shape: (batch_size, bert_input_len_wp) bert_input_wp_len = bert_input['history_encoding'].size(1) if bert_input['scenario_encoding'].size(1) > bert_input_wp_len: # Shape: (batch_size, bert_input_len_wp) bert_input['scenario_encoding'] = bert_input[ 'scenario_encoding'][:, :bert_input_wp_len] else: batch_size = bert_input['scenario_encoding'].size(0) difference = bert_input_wp_len - bert_input[ 'scenario_encoding'].size(1) zeros = torch.zeros( batch_size, difference, dtype=bert_input['scenario_encoding'].dtype, 
device=bert_input['scenario_encoding'].device) # Shape: (batch_size, bert_input_len_wp) bert_input['scenario_encoding'] = torch.cat( [bert_input['scenario_encoding'], zeros], dim=1) # Shape: (batch_size, bert_input_len + 1, embedding_dim) bert_output = self._text_field_embedder(bert_input) # Shape: (batch_size, embedding_dim) pooled_output = bert_output[:, 0] # Shape: (batch_size, bert_input_len, embedding_dim) bert_output = bert_output[:, 1:, :] # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len) passage_representation, passage_mask = self.get_passage_representation( bert_output, bert_input) # Shape: (batch_size, 4) action_logits = self._action_predictor(pooled_output) # Shape: (batch_size, passage_len, 2) span_logits = self._span_predictor(passage_representation) # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1) span_start_logits, span_end_logits = span_logits.split(1, dim=2) # Shape: (batch_size, passage_len) span_start_logits = span_start_logits.squeeze(2) # Shape: (batch_size, passage_len) span_end_logits = span_end_logits.squeeze(2) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = get_best_span(span_start_logits, span_end_logits) output_dict = { "pooled_output": pooled_output, "passage_representation": passage_representation, "action_logits": action_logits, "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span, } if self.use_scenario_encoding: output_dict["sim_token_logits"] = sim_token_logits_wp # Compute the loss for training (and for validation) if span_start is not None: # Shape: (batch_size,) span_loss = nll_loss(util.masked_log_softmax( span_start_logits, passage_mask), span_start.squeeze(1), reduction='none') # Shape: (batch_size,) span_loss += nll_loss(util.masked_log_softmax( span_end_logits, passage_mask), span_end.squeeze(1), reduction='none') # Shape: (batch_size,) more_mask = (label == self.vocab.get_token_index( 'More', namespace="labels")).float() # Shape: (batch_size,) span_loss = (span_loss * more_mask).sum() / (more_mask.sum() + 1e-6) if more_mask.sum() > 1e-7: self._span_start_accuracy(span_start_logits, span_start.squeeze(1), more_mask) self._span_end_accuracy(span_end_logits, span_end.squeeze(1), more_mask) # Shape: (batch_size, 2) span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long() self._span_accuracy(best_span, torch.cat([span_start, span_end], dim=1), span_acc_mask) action_loss = cross_entropy(action_logits, label) self._action_accuracy(action_logits, label) self._span_loss_metric(span_loss.item()) self._action_loss_metric(action_loss.item()) output_dict['loss'] = self.loss_weights[ 'span_loss'] * span_loss + self.loss_weights[ 'action_loss'] * action_loss # Compute the EM and F1 on SQuAD and add the tokenized input to the output. 
if not self.training: # true during validation and test output_dict['best_span_str'] = [] batch_size = len(metadata) for i in range(batch_size): passage_text = metadata[i]['passage_text'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_str = passage_text[start_offset:end_offset] output_dict['best_span_str'].append(best_span_str) if 'gold_span' in metadata[i]: if metadata[i]['action'] == 'More': gold_span = metadata[i]['gold_span'] self._squad_metrics(best_span_str, [gold_span]) return output_dict def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: action_probs = softmax(output_dict['action_logits'], dim=1) output_dict['action_probs'] = action_probs predictions = action_probs.cpu().data.numpy() argmax_indices = numpy.argmax(predictions, axis=1) labels = [ self.vocab.get_token_from_index(x, namespace="labels") for x in argmax_indices ] output_dict['label'] = labels return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: if self.use_scenario_encoding: sim_loss = self._sim_loss_metric.get_metric(reset) _, _, yes_f1 = self._sim_yes_f1.get_metric(reset) _, _, no_f1 = self._sim_no_f1.get_metric(reset) if self.sim_pretraining: return {'sim_macro_f1': (yes_f1 + no_f1) / 2} try: action_acc = self._action_accuracy.get_metric(reset) except ZeroDivisionError: action_acc = 0 try: start_acc = self._span_start_accuracy.get_metric(reset) except ZeroDivisionError: start_acc = 0 try: end_acc = self._span_end_accuracy.get_metric(reset) except ZeroDivisionError: end_acc = 0 try: span_acc = self._span_accuracy.get_metric(reset) except ZeroDivisionError: span_acc = 0 exact_match, f1_score = self._squad_metrics.get_metric(reset) span_loss = self._span_loss_metric.get_metric(reset) action_loss = self._action_loss_metric.get_metric(reset) agg_metric = span_acc + action_acc * 0.45 metrics = { 'action_acc': action_acc, 'span_acc': span_acc, 'span_loss': span_loss, 'action_loss': action_loss, 'agg_metric': agg_metric } if self.use_scenario_encoding: metrics['sim_macro_f1'] = (yes_f1 + no_f1) / 2 if not self.training: # during validation metrics['em'] = exact_match metrics['f1'] = f1_score return metrics @staticmethod def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor: # We call the inputs "logits" - they could either be unnormalized logits or normalized log # probabilities. A log_softmax operation is a constant shifting of the entire logit # vector, so taking an argmax over either one gives the same result. if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError( "Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() device = span_start_logits.device # (batch_size, passage_length, passage_length) span_log_probs = span_start_logits.unsqueeze( 2) + span_end_logits.unsqueeze(1) # Only the upper triangle of the span matrix is valid; the lower triangle has entries where # the span ends before it starts. span_log_mask = torch.triu( torch.ones((passage_length, passage_length), device=device)).log().unsqueeze(0) valid_span_log_probs = span_log_probs + span_log_mask # Here we take the span matrix and flatten it, then find the best span using argmax. We # can recover the start and end indices from this flattened list using simple modular # arithmetic. 
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)

    def get_input_type_ids(self, type_ids, offsets, embedder):
        "Converts wordpiece-level (bsz, seq_len_wp) type ids to token-level (bsz, seq_len) by indexing with offsets."
        batch_size = type_ids.size(0)
        full_seq_len = type_ids.size(1)

        if full_seq_len > embedder.max_pieces:
            # Recombine if we had used sliding window approach
            assert batch_size == 1 and type_ids.max() > 0
            num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero().size(0)
            select_indices = embedder.indices_to_select(full_seq_len,
                                                        num_question_tokens)
            type_ids = type_ids[:, select_indices]

        range_vector = util.get_range_vector(
            batch_size, device=util.get_device_of(type_ids)).unsqueeze(1)
        type_ids = type_ids[range_vector, offsets]
        return type_ids
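# Quick sanity check for `get_best_span` above on a toy batch of two 4-token passages
# (assumes the BertQA class defined above is importable in scope).
import torch

start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0],
                             [1.5, 0.0, 0.2, 0.1]])
end_logits = torch.tensor([[0.0, 0.2, 3.0, 0.1],
                           [0.3, 0.1, 0.0, 2.0]])
print(BertQA.get_best_span(start_logits, end_logits))
# tensor([[1, 2], [0, 3]]) -- start <= end is guaranteed by the upper-triangular mask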
class AttnSupSeq2Seq(Model): """ Adaptation of the ``SimpleSeq2Seq`` class in allennlp_models, with auxiliary attention-supervision loss Parameters ---------- vocab : ``Vocabulary``, required Vocabulary containing source and target vocabularies. They may be under the same namespace (`tokens`) or the target tokens can have a different namespace, in which case it needs to be specified as `target_namespace`. source_embedder : ``TextFieldEmbedder``, required Embedder for source side sequences encoder : ``Seq2SeqEncoder``, required The encoder of the "encoder/decoder" model max_decoding_steps : ``int`` Maximum length of decoded sequences. target_namespace : ``str``, optional (default = 'target_tokens') If the target side vocabulary is different from the source side's, you need to specify the target's namespace here. If not, we'll assume it is "tokens", which is also the default choice for the source side, and this might cause them to share vocabularies. target_embedding_dim : ``int``, optional (default = source_embedding_dim) You can specify an embedding dimensionality for the target side. If not, we'll use the same value as the source embedder's. attention : ``Attention``, optional (default = None) If you want to use attention to get a dynamic summary of the encoder outputs at each step of decoding, this is the function used to compute similarity between the decoder hidden state and encoder outputs. attention_function: ``SimilarityFunction``, optional (default = None) This is if you want to use the legacy implementation of attention. This will be deprecated since it consumes more memory than the specialized attention modules. beam_size : ``int``, optional (default = None) Width of the beam for beam search. If not specified, greedy decoding is used. scheduled_sampling_ratio : ``float``, optional (default = 0.) At each timestep during training, we sample a random number between 0 and 1, and if it is not less than this value, we use the ground truth labels for the whole batch. Else, we use the predictions from the previous time step for the whole batch. If this value is 0.0 (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not using target side ground truth labels. See the following paper for more information: `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al., 2015 <https://arxiv.org/abs/1506.03099>`_. use_bleu : ``bool``, optional (default = True) If True, the BLEU metric will be calculated during validation. """ def __init__(self, vocab: Vocabulary, source_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, max_decoding_steps: int, attention: Attention, schema_path: str = None, missing_alignment_int: int = 0, indexfield_padding_index: int = -1, beam_size: int = None, target_namespace: str = "tokens", target_embedding_dim: int = None, scheduled_sampling_ratio: float = 0., use_bleu: bool = True, emb_dropout: float = 0.0, dec_dropout: float = 0.0, attn_loss_lambda: float = 0.5, token_based_metric: Metric = None) -> None: super(AttnSupSeq2Seq, self).__init__(vocab) self._target_namespace = target_namespace self._scheduled_sampling_ratio = scheduled_sampling_ratio self._indexfield_padding_index = indexfield_padding_index self._missing_alignment_int = missing_alignment_int # We need the start symbol to provide as the input at the first timestep of decoding, and # end symbol as a way to indicate the end of the decoded sequence. 
self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) if use_bleu: pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access self._bleu = BLEU(exclude_indices={ pad_index, self._end_index, self._start_index }) else: self._bleu = None if token_based_metric: self._token_based_metric = token_based_metric else: self._token_based_metric = TokenSequenceAccuracy() # log attention supervision CE loss as a metric self._attn_sup_loss = Average() self._sql_metrics = schema_path is not None if self._sql_metrics: # SQL specific metrics: match between the templates free of schema constants, # and match between the schema constants self._schema_free_match = GlobalTemplAccuracy( schema_path=schema_path) self._kb_match = KnowledgeBaseConstsAccuracy( schema_path=schema_path) # At prediction time, we use a beam search to find the most likely sequence of target tokens. beam_size = beam_size or 1 self._max_decoding_steps = max_decoding_steps self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) # Dense embedding of source vocab tokens. self._source_embedder = source_embedder self._emb_dropout = Dropout(p=emb_dropout) self._dec_dropout = Dropout(p=dec_dropout) self._attn_loss_lambda = attn_loss_lambda # Encodes the sequence of source embeddings into a sequence of hidden states. self._encoder = encoder num_classes = self.vocab.get_vocab_size(self._target_namespace) # Attention mechanism applied to the encoder output for each step. self._attention = attention self._attention._normalize = False # Dense embedding of vocab words in the target space. target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim( ) self._target_embedder = Embedding(num_classes, target_embedding_dim) # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. self._encoder_output_dim = self._encoder.get_output_dim() self._decoder_output_dim = self._encoder_output_dim # A weighted average over encoder outputs will be concatenated to the previous target embedding # to form the input to the decoder at each time step. self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim # We'll use an LSTM cell as the recurrent cell that produces a hidden state # for the decoder at each time step. # TODO (pradeep): Do not hardcode decoder cell type. self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) # We project the hidden state from the decoder into the output vocabulary space # in order to get log probabilities of each target token, at each time step. self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) def take_step( self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Take a decoding step. This is called by the beam search class. Parameters ---------- last_predictions : ``torch.Tensor`` A tensor of shape ``(group_size,)``, which gives the indices of the predictions during the last time step. state : ``Dict[str, torch.Tensor]`` A dictionary of tensors that contain the current state information needed to predict the next step, which includes the encoder outputs, the source mask, and the decoder hidden state and context. 
Each of these tensors has shape ``(group_size, *)``, where ``*`` can be any other number of dimensions. Returns ------- Tuple[torch.Tensor, Dict[str, torch.Tensor]] A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` is a tensor of shape ``(group_size, num_classes)`` containing the predicted log probability of each class for the next step, for each item in the group, while ``updated_state`` is a dictionary of tensors containing the encoder outputs, source mask, and updated decoder hidden state and context. Notes ----- We treat the inputs as a batch, even though ``group_size`` is not necessarily equal to ``batch_size``, since the group may contain multiple states for each source sentence in the batch. """ # shape: (group_size, num_classes) _, output_projections, state = self._prepare_output_projections( last_predictions, state) # shape: (group_size, num_classes) class_log_probabilities = F.log_softmax(output_projections, dim=-1) return class_log_probabilities, state @overrides def forward_on_instances( self, instances: List[Instance]) -> List[Dict[str, numpy.ndarray]]: """ Takes a list of :class:`~allennlp.data.instance.Instance`s, converts that text into arrays using this model's :class:`Vocabulary`, passes those arrays through :func:`self.forward()` and :func:`self.decode()` (which by default does nothing) and returns the result. Before returning the result, we convert any ``torch.Tensors`` into numpy arrays and separate the batched output into a list of individual dicts per instance. Note that typically this will be faster on a GPU (and conditionally, on a CPU) than repeated calls to :func:`forward_on_instance`. Parameters ---------- instances : List[Instance], required The instances to run the model on. The prediction device is determined from the model's own parameters. Returns ------- A list of the model's outputs, one dictionary per instance. """ batch_size = len(instances) with torch.no_grad(): cuda_device = self._get_prediction_device() dataset = Batch(instances) dataset.index_instances(self.vocab) model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device) outputs = self.decode(self(**model_input)) instance_separated_output: List[Dict[str, numpy.ndarray]] = [ {} for _ in dataset.instances ] for name, output in list(outputs.items()): if isinstance(output, torch.Tensor): # NOTE(markn): This is a hack because 0-dim pytorch tensors are not iterable. # This occurs with batch size 1, because we still want to include the loss in that case. if output.dim() == 0: output = output.unsqueeze(0) if output.size(0) != batch_size: self._maybe_warn_for_unseparable_batches(name) continue output = output.detach().cpu().numpy() elif len(output) != batch_size: self._maybe_warn_for_unseparable_batches(name) continue for instance_output, batch_element in zip( instance_separated_output, output): instance_output[name] = batch_element for instance_output, instance_input in zip( instance_separated_output, instances): for field in instance_input.fields: try: instance_output[field] = instance_input.fields[ field].tokens except Exception: continue return instance_separated_output @overrides def forward( self, # type: ignore source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor] = None, alignment_sequence: torch.Tensor = None ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Make a forward pass with decoder logic for producing the entire target sequence.
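When ``target_tokens`` are given, the greedy ``_forward_loop`` runs (with optional scheduled
sampling) and produces the combined generation + attention-supervision loss; at evaluation
time a beam search additionally populates ``predictions``, from which BLEU, the token-based
metric, and the optional SQL metrics are updated.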
Parameters ---------- source_tokens : ``Dict[str, torch.LongTensor]`` The output of `TextField.as_array()` applied on the source `TextField`. This will be passed through a `TextFieldEmbedder` and then through an encoder. target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) Output of `TextField.as_array()` applied on the target `TextField`. We assume that the target tokens are also represented as a `TextField`. alignment_sequence : ``torch.Tensor``, optional (default = None) Tensor produced from the alignment ``ListField[ListField[IndexField]]``, giving for each target timestep the index of the aligned source token (steps without an alignment hold the missing-alignment marker). Returns ------- Dict[str, torch.Tensor] """ state = self._encode(source_tokens) if target_tokens: state = self._init_decoder_state(state) # Remove the trailing dimension (from ListField[ListField[IndexField]]). alignment_sequence = alignment_sequence.squeeze(-1) # The `_forward_loop` decodes the input sequence and computes the loss during training # and validation. output_dict = self._forward_loop(state, target_tokens, alignment_sequence) else: output_dict = {} if not self.training: state = self._init_decoder_state(state) predictions = self._forward_beam_search(state) output_dict.update(predictions) if target_tokens: if self._bleu: # shape: (batch_size, beam_size, max_sequence_length) top_k_predictions = output_dict["predictions"] # shape: (batch_size, max_predicted_sequence_length) best_predictions = top_k_predictions[:, 0, :] self._bleu(best_predictions, target_tokens["tokens"]) predicted_tokens = self.decode(output_dict)["predicted_tokens"] target_tokens_str = self.decode_target_tokens(target_tokens) if self._token_based_metric: self._token_based_metric(predicted_tokens, target_tokens_str) if self._sql_metrics: self._kb_match(predicted_tokens, target_tokens_str) self._schema_free_match(predicted_tokens, target_tokens_str) # If the attention module uses a coverage mechanism, reset its coverage vector after every batch. try: self._attention.reset_coverage_vector() except Exception: pass return output_dict def decode_target_tokens(self, target_tokens): target_indices = target_tokens['tokens'].detach().cpu().numpy() target_tokens_output = [] for i in range(target_indices.shape[0]): cur_target_indices = target_indices[i] cur_target_indices = list(cur_target_indices) if self._end_index in cur_target_indices: cur_target_indices = cur_target_indices[:cur_target_indices. index(self._end_index)] if self._start_index in cur_target_indices: cur_target_indices = cur_target_indices[ cur_target_indices.index(self._start_index) + 1:] target_tokens_str = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in cur_target_indices ] target_tokens_output.append(target_tokens_str) return target_tokens_output @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Finalize predictions. This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives within the ``forward`` method. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
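For example (hypothetical target vocabulary), if the best beam for an instance corresponds to
the indices of ``show me flights @end@ @@PADDING@@``, the resulting ``predicted_tokens`` entry
for that instance would be ``["show", "me", "flights"]``.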
""" predicted_indices = output_dict["predictions"] if not isinstance(predicted_indices, numpy.ndarray): predicted_indices = predicted_indices.detach().cpu().numpy() all_predicted_tokens = [] for indices in predicted_indices: # Beam search gives us the top k results for each source sentence in the batch # but we just want the single best. if len(indices.shape) > 1: indices = indices[0] indices = list(indices) # Collect indices till the first end_symbol if self._end_index in indices: indices = indices[:indices.index(self._end_index)] predicted_tokens = [ self.vocab.get_token_from_index( x, namespace=self._target_namespace) for x in indices ] all_predicted_tokens.append(predicted_tokens) output_dict["predicted_tokens"] = all_predicted_tokens return output_dict def _encode( self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # shape: (batch_size, max_input_sequence_length, encoder_input_dim) embedded_input = self._source_embedder(source_tokens) # shape: (batch_size, max_input_sequence_length) source_mask = util.get_text_field_mask(source_tokens) # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = self._encoder(embedded_input, source_mask) encoder_outputs = self._emb_dropout(encoder_outputs) return { "source_mask": source_mask, "encoder_outputs": encoder_outputs, } def _init_decoder_state( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: batch_size = state["source_mask"].size(0) # shape: (batch_size, encoder_output_dim) final_encoder_output = util.get_final_encoder_states( state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()) # Initialize the decoder hidden state with the final output of the encoder. # shape: (batch_size, decoder_output_dim) state["decoder_hidden"] = final_encoder_output # shape: (batch_size, decoder_output_dim) state["decoder_context"] = state["encoder_outputs"].new_zeros( batch_size, self._decoder_output_dim) return state def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None, alignment_sequence: torch.Tensor = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] step_attn_weights: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. 
# shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape (input_weights): (batch_size, max_input_sequence_length) # shape (output_projections): (batch_size, num_classes) input_weights, output_projections, state = self._prepare_output_projections( input_choices, state) step_attn_weights.append(input_weights.unsqueeze(1)) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(class_probabilities, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) # shape: (batch_size, num_decoding_steps, max_input_sequence_length) attention_input_weights = torch.cat(step_attn_weights[:-1], 1) output_dict = { "predictions": predictions, 'attention_input_weights': attention_input_weights } if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # shape: (batch_size, num_decoding_steps, max_input_sequence_length) alignment_mask = self._get_alignment_mask(alignment_sequence) # Compute loss. target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) attn_sup_loss = self._get_attn_sup_loss(attention_input_weights, alignment_mask, alignment_sequence) self._attn_sup_loss(attn_sup_loss.detach().cpu().item()) output_dict["loss"] = loss + self._attn_loss_lambda * attn_sup_loss return output_dict def _forward_beam_search( self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Make forward pass during prediction using a beam search.""" batch_size = state["source_mask"].size()[0] start_predictions = state["source_mask"].new_full( (batch_size, ), fill_value=self._start_index) # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) # shape (log_probabilities): (batch_size, beam_size) all_top_k_predictions, log_probabilities = self._beam_search.search( start_predictions, state, self.take_step) output_dict = { "class_log_probabilities": log_probabilities, "predictions": all_top_k_predictions, } return output_dict def _prepare_output_projections(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long """ Decode the current state and last prediction to produce projections into the target space, which can then be used to get probabilities of each target token for the next step. Inputs are the same as for `take_step()`.
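Unlike the base ``SimpleSeq2Seq``, this method also returns the unnormalized attention logits
over the source tokens (``input_weights``); ``_forward_loop`` collects them for the
attention-supervision loss.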
""" # shape: (group_size, max_input_sequence_length, encoder_output_dim) encoder_outputs = state["encoder_outputs"] # shape: (group_size, max_input_sequence_length) source_mask = state["source_mask"] # shape: (group_size, decoder_output_dim) decoder_hidden = state["decoder_hidden"] # shape: (group_size, decoder_output_dim) decoder_context = state["decoder_context"] # shape: (group_size, target_embedding_dim) embedded_input = self._target_embedder(last_predictions) # shape: (group_size, encoder_output_dim) attended_input, input_weights = self._prepare_attended_input( decoder_hidden, encoder_outputs, source_mask) # shape: (group_size, decoder_output_dim + target_embedding_dim) decoder_input = torch.cat((attended_input, embedded_input), -1) decoder_input = self._dec_dropout(decoder_input) # shape (decoder_hidden): (batch_size, decoder_output_dim) # shape (decoder_context): (batch_size, decoder_output_dim) decoder_hidden, decoder_context = self._decoder_cell( decoder_input, (decoder_hidden, decoder_context)) state["decoder_hidden"] = decoder_hidden state["decoder_context"] = decoder_context # shape: (group_size, num_classes) output_projections = self._output_projection_layer( self._dec_dropout(decoder_hidden)) return input_weights, output_projections, state def _prepare_attended_input( self, decoder_hidden_state: torch.LongTensor = None, encoder_outputs: torch.LongTensor = None, encoder_outputs_mask: torch.LongTensor = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply attention over encoder outputs and decoder state.""" # Ensure mask is also a FloatTensor. Or else the multiplication within # attention will complain. # shape: (batch_size, max_input_sequence_length, encoder_output_dim) encoder_outputs_mask = encoder_outputs_mask.float() # shape: (batch_size, max_input_sequence_length) input_logits = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask) # the attention mechanism returns the logits that are necessary for attention supervision loss, # so we normalize it here input_weights = masked_softmax(input_logits, encoder_outputs_mask) # shape: (batch_size, encoder_output_dim) attended_input = util.weighted_sum(encoder_outputs, input_weights) return attended_input, input_logits @staticmethod def _get_attn_sup_loss(attn_weights: torch.Tensor, alignment_mask: torch.Tensor, alignment_sequence: torch.Tensor) -> torch.Tensor: """ Compute the attention supervision CE loss. For each step, take the index of the aligned """ # shape: (batch_size, max_decoding_steps, max_input_seq_length attn_weights = attn_weights.float() alignment_sequence[alignment_sequence == -1] = 0 # for each attn_weights[batch_index, step_index, :] I want to choose the index of # alignment_sequence[batch_index, step_index] return util.sequence_cross_entropy_with_logits(attn_weights, alignment_sequence, alignment_mask) def _get_alignment_mask(self, alignment_sequence): """ The alignment mask includes the target mask + mask on steps that don't have alignment shape: batch_size, max_steps, max_input """ pad_mask = alignment_sequence != self._indexfield_padding_index missing_mask = alignment_sequence != self._missing_alignment_int return pad_mask * missing_mask @staticmethod def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor) -> torch.Tensor: """ Compute loss. 
Takes logits (unnormalized outputs from the decoder) of size (batch_size, num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross entropy loss while taking the mask into account. The length of ``targets`` is expected to be greater than that of ``logits`` because the decoder does not need to compute the output corresponding to the last timestep of ``targets``. This method aligns the inputs appropriately to compute the loss. During training, we want the logit corresponding to timestep i to be similar to the target token from timestep i + 1. That is, the targets should be shifted by one timestep for appropriate comparison. Consider a single example where the target has 3 words, and padding is to 7 tokens. The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> and the mask would be 1 1 1 1 1 0 0 and let the logits be l1 l2 l3 l4 l5 l6 We actually need to compare: the sequence w1 w2 w3 <E> <P> <P> with masks 1 1 1 1 0 0 against l1 l2 l3 l4 l5 l6 (where the input was) <S> w1 w2 w3 <E> <P> """ # shape: (batch_size, num_decoding_steps) relevant_targets = targets[:, 1:].contiguous() # shape: (batch_size, num_decoding_steps) relevant_mask = target_mask[:, 1:].contiguous() return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if not self.training: if self._bleu: all_metrics.update(self._bleu.get_metric(reset=reset)) all_metrics.update( self._token_based_metric.get_metric(reset=reset)) if self._sql_metrics: all_metrics.update(self._kb_match.get_metric(reset=reset)) all_metrics.update( self._schema_free_match.get_metric(reset=reset)) all_metrics['attn_sup_loss'] = self._attn_sup_loss.get_metric( reset=reset) return all_metrics
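# ---------------------------------------------------------------------------
# Illustrative sketch (added; not part of the original model): a standalone toy
# version of the attention-supervision loss computed in `_get_attn_sup_loss`.
# The model delegates to `util.sequence_cross_entropy_with_logits`; the plain
# PyTorch version below makes the per-step cross-entropy and the alignment
# masking explicit. All tensor values are made up for demonstration.
import torch
import torch.nn.functional as F


def toy_attention_supervision_loss(attn_logits: torch.Tensor,
                                   alignment: torch.Tensor,
                                   mask: torch.Tensor) -> torch.Tensor:
    # attn_logits: (batch, steps, src_len) unnormalized attention scores.
    # alignment:   (batch, steps) index of the aligned source token per step.
    # mask:        (batch, steps) 1 for steps that actually have an alignment.
    batch, steps, src_len = attn_logits.size()
    # Cross-entropy over source positions, one term per decoding step.
    ce = F.cross_entropy(attn_logits.reshape(batch * steps, src_len),
                         alignment.reshape(batch * steps),
                         reduction="none").reshape(batch, steps)
    mask = mask.float()
    # Average over supervised steps per example, then over the batch
    # (mirroring the 'batch' averaging of sequence_cross_entropy_with_logits).
    per_example = (ce * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return per_example.mean()


if __name__ == "__main__":
    attn_logits = torch.randn(2, 3, 5)           # 2 examples, 3 steps, 5 source tokens
    alignment = torch.tensor([[0, 2, 4], [1, 1, 0]])
    mask = torch.tensor([[1, 1, 0], [1, 0, 0]])  # unsupervised steps are masked out
    print(toy_attention_supervision_loss(attn_logits, alignment, mask))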