Code example #1
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        hparams: Dict,
    ) -> None:
        super().__init__(vocab)
        self.text_field_embedder = text_field_embedder

        self.contextualizer = contextualizer
        self.bidirectional = contextualizer.is_bidirectional()

        if self.bidirectional:
            self.forward_dim = contextualizer.get_output_dim() // 2
        else:
            self.forward_dim = contextualizer.get_output_dim()

        dropout = hparams["dropout"]
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = lambda x: x

        self.hidden2chord = torch.nn.Sequential(
            torch.nn.Linear(self.forward_dim, hparams["fc_hidden_dim"]),
            torch.nn.ReLU(True),
            torch.nn.Linear(hparams["fc_hidden_dim"], vocab.get_vocab_size()),
        )
        self.perplexity = PerplexityCustom()
        self.accuracy = CategoricalAccuracy()
        self.real_loss = Average()

        self.similarity_matrix = hparams["similarity_matrix"]
        self.training_mode = hparams["training_mode"]

        self.T_initial = hparams["T_initial"]
        self.T = self.T_initial
        self.decay_rate = hparams["decay_rate"]

        self.batches_per_epoch = hparams["batches_per_epoch"]
        self.epoch = 0
        self.batch_counter = 0
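
The constructor above pulls every hyperparameter out of a single `hparams` dict. A minimal sketch of a dict with the keys it actually reads follows; the example values, and the choice of an `nn.Embedding` as the similarity matrix (which matches how it is later called on target indices), are assumptions rather than values from the project.

import torch

vocab_size = 170  # assumed; in practice vocab.get_vocab_size()
hparams = {
    "dropout": 0.1,
    "fc_hidden_dim": 256,
    # Called later as self.similarity_matrix(targets), so an Embedding-like module is assumed.
    "similarity_matrix": torch.nn.Embedding(vocab_size, vocab_size),
    "training_mode": "one_hot",   # assumed label; the model compares it to TM_* constants
    "T_initial": 1.0,
    "decay_rate": 0.01,
    "batches_per_epoch": 500,
}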
Code example #2
File: gnli.py  Project: anthonywchen/generative-nli
    def __init__(
        self,
        pretrained_model: str,
        discriminative_loss_weight: float = 0,
        vocab: Vocabulary = Vocabulary(),
        softmax_over_vocab: bool = False,
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super(GNLI, self).__init__(vocab)
        # Check the arguments of `__init__()`.
        assert pretrained_model in ['bart.large']
        assert 0 <= discriminative_loss_weight <= 1

        # Load in BART and extend the embeddings layer by three for the label embeddings.
        self._bart = torch.hub.load('pytorch/fairseq', pretrained_model).model
        self._extend_embeddings()

        # Ignore padding indices when calculating generative loss.
        assert self._bart.encoder.padding_idx == 1
        self._generative_loss_fn = torch.nn.CrossEntropyLoss(
            ignore_index=self._bart.encoder.padding_idx)
        self._discriminative_loss_fn = torch.nn.NLLLoss()
        self._discriminative_loss_weight = discriminative_loss_weight

        self._softmax_over_vocab = softmax_over_vocab
        if self._softmax_over_vocab:
            self.effective_vocab_size = self.vocab_size
        else:
            self.effective_vocab_size = self.vocab_size + self.label_size

        self.metrics = {
            'accuracy': CategoricalAccuracy(),
            'disc_loss': Average(),
            'gen_loss': Average()
        }

        initializer(self)
        number_params = sum([
            numpy.prod(p.size()) for p in list(self.parameters())
            if p.requires_grad
        ])
        logger.info('Number of trainable model parameters: %d', number_params)
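
The last few lines above count trainable parameters inline. The same computation, factored into a small reusable helper (a sketch, not part of the original file):

import numpy
import torch

def count_trainable_parameters(model: torch.nn.Module) -> int:
    # Mirrors the inline sum above: product of each trainable parameter's shape.
    return int(sum(numpy.prod(p.size()) for p in model.parameters() if p.requires_grad))

# Usage sketch:
# logger.info('Number of trainable model parameters: %d', count_trainable_parameters(model))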
Code example #3
File: bert_qa.py  Project: pombredanne/UrcaNet
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sim_text_field_embedder: TextFieldEmbedder,
                 loss_weights: Dict,
                 sim_class_weights: List,
                 pretrained_sim_path: str = None,
                 use_scenario_encoding: bool = True,
                 sim_pretraining: bool = False,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BertQA, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        if use_scenario_encoding:
            self._sim_text_field_embedder = sim_text_field_embedder
        self.loss_weights = loss_weights
        self.sim_class_weights = sim_class_weights
        self.use_scenario_encoding = use_scenario_encoding
        self.sim_pretraining = sim_pretraining

        if self.sim_pretraining and not self.use_scenario_encoding:
            raise ValueError(
                "When pretraining Scenario Interpretation Module, you should use it."
            )

        embedding_dim = self._text_field_embedder.get_output_dim()
        self._action_predictor = torch.nn.Linear(embedding_dim, 4)
        self._sim_token_label_predictor = torch.nn.Linear(embedding_dim, 4)
        self._span_predictor = torch.nn.Linear(embedding_dim, 2)
        self._action_accuracy = CategoricalAccuracy()
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_loss_metric = Average()
        self._action_loss_metric = Average()
        self._sim_loss_metric = Average()
        self._sim_yes_f1 = F1Measure(2)
        self._sim_no_f1 = F1Measure(3)

        if use_scenario_encoding and pretrained_sim_path is not None:
            logger.info("Loading pretrained model..")
            self.load_state_dict(torch.load(pretrained_sim_path))
            for param in self._sim_text_field_embedder.parameters():
                param.requires_grad = False

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)
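
Several of these constructors fall back to `lambda x: x` when dropout is disabled. A drop-in alternative is `torch.nn.Identity()`, which keeps the attribute a proper submodule (so it participates in `state_dict()` and `.to(device)`). The snippet below is an optional refactor sketch, not the project's code.

import torch

def make_dropout(p: float) -> torch.nn.Module:
    # Identity behaves exactly like `lambda x: x` but remains a registered nn.Module.
    return torch.nn.Dropout(p=p) if p > 0 else torch.nn.Identity()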
Code example #4
class Cpm(Model):
    """
    The ``Cpm`` model applies a "contextualizing" ``Seq2SeqEncoder`` to uncontextualized
    embeddings and uses ``torch.nn.functional.kl_div`` to compute the language modeling loss.
    If bidirectional is True,  the language model is trained to predict the next and
    previous tokens for each token in the input. In this case, the contextualizer must
    be bidirectional. If bidirectional is False, the language model is trained to only
    predict the next token for each token in the input; the contextualizer should also
    be unidirectional.
    If your language model is bidirectional, it is IMPORTANT that your bidirectional
    ``Seq2SeqEncoder`` contextualizer does not do any "peeking ahead". That is, for its
    forward direction it should only consider embeddings at previous timesteps, and for
    its backward direction only embeddings at subsequent timesteps. Similarly, if your
    language model is unidirectional, the unidirectional contextualizer should only
    consider embeddings at previous timesteps. If this condition is not met, your
    language model is cheating.
    Parameters
    ----------
    vocab: ``Vocabulary``
    text_field_embedder: ``TextFieldEmbedder``
        Used to embed the indexed tokens we get in ``forward``.
    contextualizer: ``Seq2SeqEncoder``
        Used to "contextualize" the embeddings. As described above,
        this encoder must not cheat by peeking ahead.
    dropout: ``float``, optional (default: None)
        If specified, dropout is applied to the contextualized embeddings before computation of
        the softmax. The contextualized embeddings themselves are returned without dropout.
    bidirectional: ``bool``, optional (default: False)
        Train a bidirectional language model, where the contextualizer
        is used to predict the next and previous token for each input token.
        This must match the bidirectionality of the contextualizer.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        hparams: Dict,
    ) -> None:
        super().__init__(vocab)
        self.text_field_embedder = text_field_embedder

        self.contextualizer = contextualizer
        self.bidirectional = contextualizer.is_bidirectional()

        if self.bidirectional:
            self.forward_dim = contextualizer.get_output_dim() // 2
        else:
            self.forward_dim = contextualizer.get_output_dim()

        dropout = hparams["dropout"]
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = lambda x: x

        self.hidden2chord = torch.nn.Sequential(
            torch.nn.Linear(self.forward_dim, hparams["fc_hidden_dim"]),
            torch.nn.ReLU(True),
            torch.nn.Linear(hparams["fc_hidden_dim"], vocab.get_vocab_size()),
        )
        self.perplexity = PerplexityCustom()
        self.accuracy = CategoricalAccuracy()
        self.real_loss = Average()

        self.similarity_matrix = hparams["similarity_matrix"]
        self.training_mode = hparams["training_mode"]

        self.T_initial = hparams["T_initial"]
        self.T = self.T_initial
        self.decay_rate = hparams["decay_rate"]

        self.batches_per_epoch = hparams["batches_per_epoch"]
        self.epoch = 0
        self.batch_counter = 0

    def num_layers(self) -> int:
        """
        Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
        the non-contextual layer.
        """
        if hasattr(self.contextualizer, "num_layers"):
            return self.contextualizer.num_layers + 1
        else:
            raise NotImplementedError(
                f"Contextualizer of type {type(self.contextualizer)} " +
                "does not report how many layers it has.")

    def loss_helper(self, direction_embeddings: torch.Tensor,
                    direction_targets: torch.Tensor):
        mask = direction_targets > 0
        # we need to subtract 1 to undo the padding id since the softmax
        # does not include a padding dimension

        # shape (batch_size * timesteps, )
        non_masked_targets = direction_targets.masked_select(mask)

        # shape (batch_size * timesteps, embedding_dim)
        non_masked_embeddings = direction_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self.forward_dim)
        # note: need to return average loss across forward and backward
        # directions, but total sum loss across all batches.
        # Assuming batches include full sentences, forward and backward
        # directions have the same number of samples, so sum up loss
        # here then divide by 2 just below
        probs = torch.nn.functional.log_softmax(
            self.hidden2chord(non_masked_embeddings), dim=-1)

        real_loss = torch.nn.functional.nll_loss(probs,
                                                 non_masked_targets,
                                                 reduction="sum")
        # transform targets into probability distributions using Embedding
        # then compute loss using torch.nn.functional.kl_div
        if self.training:
            if self.training_mode == TM_ONE_HOT:
                train_loss = real_loss
            elif self.training_mode == TM_NO:
                target_distributions = self.similarity_matrix(
                    non_masked_targets)
                train_loss = torch.nn.functional.kl_div(probs,
                                                        target_distributions,
                                                        reduction="sum")
            elif self.training_mode == TM_FIXED or self.training_mode == TM_DECREASED:
                target_distributions = self.similarity_matrix(
                    non_masked_targets)
                target_distributions = torch.nn.functional.softmax(
                    target_distributions / self.T, dim=1)
                train_loss = torch.nn.functional.kl_div(probs,
                                                        target_distributions,
                                                        reduction="sum")
            else:
                raise ValueError("Unknown training mode: {}".format(
                    self.training_mode))
        else:
            train_loss = real_loss
        return train_loss, real_loss

    @overrides
    def forward(
        self,
        input_tokens: Dict[str, torch.LongTensor],
        forward_output_tokens: Dict[str, torch.LongTensor],
        backward_output_tokens: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Computes the averaged forward (and backward, if language model is bidirectional)
        LM loss from the batch.
        Returns
        -------
        Dict with keys:
        ``'loss'``: ``torch.Tensor``
            forward negative log likelihood, or the average of forward/backward
            if language model is bidirectional
        ``'forward_loss'``: ``torch.Tensor``
            forward direction negative log likelihood
        ``'backward_loss'``: ``torch.Tensor`` or ``None``
            backward direction negative log likelihood. If language model is not
            bidirectional, this is ``None``.
        ``'contextual_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]``
            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
            list of all layers. No dropout applied.
        ``'noncontextual_token_embeddings'``: ``torch.Tensor``
            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
            representations
        ``'mask'``: ``torch.Tensor``
            (batch_size, timesteps) mask for the embeddings
        """
        self.batch_counter += 1
        if self.batch_counter % self.batches_per_epoch == 0:
            self.epoch += 1
            if self.training_mode == TM_DECREASED:
                self.T *= 1 / (1 + self.decay_rate * self.epoch)
                if self.T < 1e-20:
                    self.T = 1e-20

        mask = get_text_field_mask(input_tokens)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self.text_field_embedder(input_tokens)

        contextual_embeddings = self.contextualizer(embeddings, mask)
        contextual_embeddings_with_dropout = self.dropout(
            contextual_embeddings)

        if self.bidirectional:
            forward_embeddings, backward_embeddings = contextual_embeddings_with_dropout.chunk(
                2, -1)
            backward_logits = self.hidden2chord(backward_embeddings)
        else:
            forward_embeddings = contextual_embeddings_with_dropout
            backward_logits = None
        forward_logits = self.hidden2chord(forward_embeddings)

        forward_targets = forward_output_tokens.get("tokens")
        if self.bidirectional:
            backward_targets = backward_output_tokens.get("tokens")

        # compute loss
        forward_loss, forward_real_loss = self.loss_helper(
            forward_embeddings, forward_targets)
        if self.bidirectional:
            backward_loss, backward_real_loss = self.loss_helper(
                backward_embeddings, backward_targets)
        else:
            backward_loss, backward_real_loss = None, None

        return_dict = {}

        num_targets = torch.sum((forward_targets > 0).long())
        if num_targets > 0:
            if self.bidirectional:
                average_loss = (0.5 * (forward_loss + backward_loss) /
                                num_targets.float())
                average_real_loss = (0.5 *
                                     (forward_real_loss + backward_real_loss) /
                                     num_targets.float())
            else:
                average_loss = forward_loss / num_targets.float()
                average_real_loss = forward_real_loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0).to(forward_targets.device)
            average_real_loss = torch.tensor(0.0).to(forward_targets.device)

        self.perplexity(average_real_loss)
        self.accuracy(forward_logits, forward_targets, mask)
        self.real_loss(average_real_loss)

        return_dict.update({"loss": average_loss})

        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            "contextual_embeddings": contextual_embeddings,
            "noncontextual_token_embeddings": embeddings,
            "forward_logits": forward_logits,
            "backward_logits": backward_logits,
            "mask": mask,
        })

        return return_dict

    def get_metrics(self, reset: bool = False):
        return {
            "perplexity": self.perplexity.get_metric(reset=reset),
            "accuracy": self.accuracy.get_metric(reset=reset),
            "real_loss": float(self.real_loss.get_metric(reset=reset)),
        }
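
The training modes in `loss_helper` differ only in how the target distribution for `kl_div` is built. Below is a standalone sketch of the temperature-softened variant (the `TM_FIXED` / `TM_DECREASED` branch); the tensor sizes, the temperature, and the use of a randomly initialized `nn.Embedding` as the similarity matrix are assumptions for illustration.

import torch

hidden_dim, vocab_size, num_tokens, T = 16, 50, 8, 2.0  # assumed sizes and temperature
similarity_matrix = torch.nn.Embedding(vocab_size, vocab_size)
hidden2chord = torch.nn.Linear(hidden_dim, vocab_size)

embeddings = torch.randn(num_tokens, hidden_dim)        # stand-in for non-masked contextual embeddings
targets = torch.randint(0, vocab_size, (num_tokens,))   # stand-in for non-masked target ids

log_probs = torch.nn.functional.log_softmax(hidden2chord(embeddings), dim=-1)
target_distributions = torch.nn.functional.softmax(similarity_matrix(targets) / T, dim=1)
train_loss = torch.nn.functional.kl_div(log_probs, target_distributions, reduction="sum")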
Code example #5
    def __init__(self,
                 vocab: Vocabulary,
                 dataset_reader: DatasetReader,
                 source_embedder: TextFieldEmbedder,
                 lang2_namespace: str = "tokens",
                 use_bleu: bool = True) -> None:
        super().__init__(vocab)
        self._lang1_namespace = lang2_namespace  # TODO: DO NOT HARDCODE IT
        self._lang2_namespace = lang2_namespace

        # TODO: do not hardcode this
        self._backtranslation_src_langs = ["en", "ru"]
        self._coeff_denoising = 1
        self._coeff_backtranslation = 1
        self._coeff_translation = 1

        self._label_smoothing = 0.1

        self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang1_namespace)
        self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang1_namespace)
        self._end_index_lang1 = self.vocab.get_token_index(
            END_SYMBOL, self._lang1_namespace)

        self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang2_namespace)
        self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang2_namespace)
        self._end_index_lang2 = self.vocab.get_token_index(
            END_SYMBOL, self._lang2_namespace)

        self._reader = dataset_reader
        self._langs_list = self._reader._langs_list
        self._ae_steps = self._reader._ae_steps
        self._bt_steps = self._reader._bt_steps
        self._para_steps = self._reader._para_steps

        if use_bleu:
            self._bleu = Average()
        else:
            self._bleu = None

        args = ArgsStub()

        transformer_iwslt_de_en(args)

        # build encoder
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Dense embedding of vocab words in the target space.
        num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
        num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

        args.share_decoder_input_output_embed = False  # TODO implement shared embeddings

        lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                              pad=self._pad_index_lang1,
                              unk=self._oov_index_lang1,
                              eos=self._end_index_lang1)

        lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                              pad=self._pad_index_lang2,
                              unk=self._oov_index_lang2,
                              eos=self._end_index_lang2)

        # instantiate fairseq classes
        emb_golden_tokens = FairseqEmbedding(num_tokens_lang2,
                                             args.decoder_embed_dim,
                                             self._pad_index_lang2)

        self._encoder = TransformerEncoder(args, lang1_dict,
                                           self._source_embedder)
        self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
        self._model = TransformerModel(self._encoder, self._decoder)

        # TODO: do not hardcode max_len_b and beam size
        self._sequence_generator_greedy = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
        self._sequence_generator_beam = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))
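
`ArgsStub` and `DictStub` are project-specific adapters rather than fairseq classes; only their constructor keywords are visible here. A hypothetical sketch of the dictionary adapter, assuming the fairseq-style components query it through the usual `__len__` / `pad()` / `unk()` / `eos()` accessors:

class DictStub:
    """Hypothetical adapter; the accessor names are an assumption, only the
    constructor arguments (num_tokens, pad, unk, eos) appear in the code above."""

    def __init__(self, num_tokens: int, pad: int, unk: int, eos: int) -> None:
        self._num_tokens, self._pad, self._unk, self._eos = num_tokens, pad, unk, eos

    def __len__(self) -> int:
        return self._num_tokens

    def pad(self) -> int:
        return self._pad

    def unk(self) -> int:
        return self._unk

    def eos(self) -> int:
        return self._eos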
Code example #6
class UnsupervisedTranslation(Model):
    """
    This ``SimpleSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'target_tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function: ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 dataset_reader: DatasetReader,
                 source_embedder: TextFieldEmbedder,
                 lang2_namespace: str = "tokens",
                 use_bleu: bool = True) -> None:
        super().__init__(vocab)
        self._lang1_namespace = lang2_namespace  # TODO: DO NOT HARDCODE IT
        self._lang2_namespace = lang2_namespace

        # TODO: do not hardcode this
        self._backtranslation_src_langs = ["en", "ru"]
        self._coeff_denoising = 1
        self._coeff_backtranslation = 1
        self._coeff_translation = 1

        self._label_smoothing = 0.1

        self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang1_namespace)
        self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang1_namespace)
        self._end_index_lang1 = self.vocab.get_token_index(
            END_SYMBOL, self._lang1_namespace)

        self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang2_namespace)
        self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang2_namespace)
        self._end_index_lang2 = self.vocab.get_token_index(
            END_SYMBOL, self._lang2_namespace)

        self._reader = dataset_reader
        self._langs_list = self._reader._langs_list
        self._ae_steps = self._reader._ae_steps
        self._bt_steps = self._reader._bt_steps
        self._para_steps = self._reader._para_steps

        if use_bleu:
            self._bleu = Average()
        else:
            self._bleu = None

        args = ArgsStub()

        transformer_iwslt_de_en(args)

        # build encoder
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Dense embedding of vocab words in the target space.
        num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
        num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

        args.share_decoder_input_output_embed = False  # TODO implement shared embeddings

        lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                              pad=self._pad_index_lang1,
                              unk=self._oov_index_lang1,
                              eos=self._end_index_lang1)

        lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                              pad=self._pad_index_lang2,
                              unk=self._oov_index_lang2,
                              eos=self._end_index_lang2)

        # instantiate fairseq classes
        emb_golden_tokens = FairseqEmbedding(num_tokens_lang2,
                                             args.decoder_embed_dim,
                                             self._pad_index_lang2)

        self._encoder = TransformerEncoder(args, lang1_dict,
                                           self._source_embedder)
        self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
        self._model = TransformerModel(self._encoder, self._decoder)

        # TODO: do not hardcode max_len_b and beam size
        self._sequence_generator_greedy = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
        self._sequence_generator_beam = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))

    @overrides
    def forward(
        self,  # type: ignore
        lang_pair: List[str],
        lang1_tokens: Dict[str, torch.LongTensor] = None,
        lang1_golden: Dict[str, torch.LongTensor] = None,
        lang2_tokens: Dict[str, torch.LongTensor] = None,
        lang2_golden: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        """
        # detect training mode and what kind of task we need to compute
        if lang2_tokens is None and lang1_tokens is None:
            raise ConfigurationError(
                "source_tokens and target_tokens can not both be None")

        mode_training = self.training
        mode_validation = not self.training and lang2_tokens is not None  # change 'target_tokens' condition
        mode_prediction = lang2_tokens is None  # change 'target_tokens' condition

        lang_src, lang_tgt = lang_pair[0].split('-')

        if mode_training:
            # task types
            task_translation = False
            task_denoising = False
            task_backtranslation = False

            if lang_src == 'xx':
                task_backtranslation = True
            elif lang_src == lang_tgt:
                task_denoising = True
            elif lang_src != lang_tgt:
                task_translation = True
            else:
                raise ConfigurationError("All tasks are false")

        output_dict = {}
        if mode_training:

            if task_translation:
                loss = self._forward_seq2seq(lang_pair, lang1_tokens,
                                             lang2_tokens, lang2_golden)
                if self._bleu:
                    predicted_indices = self._sequence_generator_beam.generate(
                        [self._model], lang1_tokens,
                        self._get_true_pad_mask(lang1_tokens),
                        self._end_index_lang2)
                    predicted_strings = self._indices_to_strings(
                        predicted_indices)
                    golden_strings = self._indices_to_strings(
                        lang2_tokens["tokens"])
                    golden_strings = self._remove_pad_eos(golden_strings)
                    # print(golden_strings, predicted_strings)
                    self._bleu(corpus_bleu(golden_strings, predicted_strings))
            elif task_denoising:  # might need to split it into two blocks for interlingua loss
                loss = self._forward_seq2seq(lang_pair, lang1_tokens,
                                             lang2_tokens, lang2_golden)
            elif task_backtranslation:
                # our goal is also to learn from regular cross-entropy loss, but since we do not have source tokens,
                # we will generate them ourselves with current model
                langs_src = self._backtranslation_src_langs.copy()
                langs_src.remove(lang_tgt)
                bt_losses = {}
                for lang_src in langs_src:
                    curr_lang_pair = lang_src + "-" + lang_tgt
                    # TODO: require to pass target language to forward on encoder outputs
                    # We use greedy decoder because it was shown better for backtranslation
                    with torch.no_grad():
                        predicted_indices = self._sequence_generator_greedy.generate(
                            [self._model], lang2_tokens,
                            self._get_true_pad_mask(lang2_tokens),
                            self._end_index_lang2)
                    model_input = self._strings_to_batch(
                        self._indices_to_strings(predicted_indices),
                        lang2_tokens, lang2_golden, curr_lang_pair)
                    bt_losses['bt:' + curr_lang_pair] = self._forward_seq2seq(
                        **model_input)
            else:
                raise ConfigurationError("No task have been detected")

            if task_translation:
                loss = self._coeff_translation * loss
            elif task_denoising:
                loss = self._coeff_denoising * loss
            elif task_backtranslation:
                loss = 0
                for bt_loss in bt_losses.values():
                    loss += self._coeff_backtranslation * bt_loss

            output_dict["loss"] = loss

        elif mode_validation:
            output_dict["loss"] = self._coeff_translation * \
                                  self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden)
            if self._bleu:
                predicted_indices = self._sequence_generator_greedy.generate(
                    [self._model], lang1_tokens,
                    self._get_true_pad_mask(lang1_tokens),
                    self._end_index_lang2)
                predicted_strings = self._indices_to_strings(predicted_indices)
                golden_strings = self._indices_to_strings(
                    lang2_tokens["tokens"])
                golden_strings = self._remove_pad_eos(golden_strings)
                print(golden_strings, predicted_strings)
                self._bleu(corpus_bleu(golden_strings, predicted_strings))

        elif mode_prediction:
            # TODO: pass target language (in the fseq_encoder append embedded target language to the encoder out)
            predicted_indices = self._sequence_generator_beam.generate(
                [self._model], lang1_tokens,
                self._get_true_pad_mask(lang1_tokens), self._end_index_lang2)
            output_dict["predicted_indices"] = predicted_indices
            output_dict["predicted_strings"] = self._indices_to_strings(
                predicted_indices)

        return output_dict

    def _get_true_pad_mask(self, indexed_input):
        mask = util.get_text_field_mask(indexed_input)
        # TODO: account for cases when text field mask doesn't work, like BERT
        return mask

    def _remove_pad_eos(self, golden_strings):
        tmp = []
        for x in golden_strings:
            tmp.append(
                list(
                    filter(
                        lambda a: a != DEFAULT_PADDING_TOKEN and a !=
                        END_SYMBOL, x)))
        return tmp

    def _convert_to_sentences(self, golden_strings, predicted_strings):
        golden_strings_nopad = []
        for s in golden_strings:
            s_nopad = list(filter(lambda t: t != DEFAULT_PADDING_TOKEN, s))
            s_nopad = " ".join(s_nopad)
            golden_strings_nopad.append(s_nopad)
        predicted_strings = [" ".join(s) for s in predicted_strings]
        return golden_strings_nopad, predicted_strings

    def _forward_seq2seq(
            self, lang_pair: List[str], source_tokens: Dict[str,
                                                            torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor],
            target_golden: Dict[str,
                                torch.LongTensor]) -> Dict[str, torch.Tensor]:
        source_tokens_padding_mask = self._get_true_pad_mask(source_tokens)
        encoder_out = self._encoder.forward(source_tokens,
                                            source_tokens_padding_mask)
        logits, _ = self._decoder.forward(target_tokens["tokens"], encoder_out)
        loss = self._get_ce_loss(logits, target_golden)
        return loss

    def _get_ce_loss(self, logits, golden):
        target_mask = util.get_text_field_mask(golden)
        loss = util.sequence_cross_entropy_with_logits(
            logits,
            golden["golden_tokens"],
            target_mask,
            label_smoothing=self._label_smoothing)
        return loss

    def _indices_to_strings(self, indices: torch.Tensor):
        all_predicted_tokens = []
        for hyp in indices:
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    idx.item(), namespace=self._lang2_namespace) for idx in hyp
            ]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    def _strings_to_batch(self, source_tokens: List[List[str]],
                          target_tokens: Dict[str, torch.Tensor],
                          target_golden: Dict[str,
                                              torch.Tensor], lang_pair: str):
        """
        Converts list of sentences which are itself lists of strings into Batch
        suitable for passing into model's forward function.

        TODO: Make sure the right device (CPU/GPU) is used. Predicted tokens might get copied on
        CPU in `self.decode` method...
        """
        # convert source tokens into source tensor_dict
        instances = []
        lang_pairs = []
        for sentence in source_tokens:
            sentence = " ".join(sentence)
            instances.append(self._reader.string_to_instance(sentence))
            lang_pairs.append(lang_pair)

        source_batch = Batch(instances)
        source_batch.index_instances(self.vocab)
        source_batch = source_batch.as_tensor_dict()
        model_input = {
            "source_tokens": source_batch["tokens"],
            "target_golden": target_golden,
            "target_tokens": target_tokens,
            "lang_pair": lang_pairs
        }

        return model_input

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update({"BLEU": self._bleu.get_metric(reset=reset)})
        return all_metrics
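
The task dispatch in `forward` hinges on the language pair alone: a source of `'xx'` marks a back-translation batch, equal languages mark denoising, and different languages mark supervised translation. A distilled restatement of that rule as a pure function, for reference only (not part of the class):

def detect_task(lang_src: str, lang_tgt: str) -> str:
    # Same branching as in forward(): 'xx' -> backtranslation, same language -> denoising.
    if lang_src == "xx":
        return "backtranslation"
    if lang_src == lang_tgt:
        return "denoising"
    return "translation"

assert detect_task("xx", "en") == "backtranslation"
assert detect_task("en", "en") == "denoising"
assert detect_task("en", "ru") == "translation"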
Code example #7
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 cnn_size: int = 100,
                 dropout_weight: float = 0.1,
                 with_entity_embeddings: bool = True,
                 sent_loss_weight: float = 1,
                 attention_weight_fn: str = 'sigmoid',
                 attention_aggregation_fn: str = 'max') -> None:
        regularizer = None
        super().__init__(vocab, regularizer)
        self.num_classes = self.vocab.get_vocab_size("labels")

        self.text_field_embedder = text_field_embedder
        self.dropout_weight = dropout_weight
        self.with_entity_embeddings = with_entity_embeddings
        self.sent_loss_weight = sent_loss_weight
        self.attention_weight_fn = attention_weight_fn
        self.attention_aggregation_fn = attention_aggregation_fn

        # instantiate position embedder
        pos_embed_output_size = 5
        pos_embed_input_size = 2 * RelationInstancesReader.max_distance + 1
        self.pos_embed = nn.Embedding(pos_embed_input_size,
                                      pos_embed_output_size)
        pos_embed_weights = np.array([range(pos_embed_input_size)] *
                                     pos_embed_output_size).T
        self.pos_embed.weight = nn.Parameter(torch.Tensor(pos_embed_weights))

        d = cnn_size
        sent_encoder = CnnEncoder  # TODO: should be moved to the config file
        cnn_output_size = d
        embedding_size = 300  # TODO: should be moved to the config file

        # instantiate sentence encoder
        self.cnn = sent_encoder(embedding_dim=(embedding_size +
                                               2 * pos_embed_output_size),
                                num_filters=cnn_size,
                                ngram_filter_sizes=(2, 3, 4, 5),
                                conv_layer_activation=torch.nn.ReLU(),
                                output_dim=cnn_output_size)

        # dropout after word embedding
        self.dropout = nn.Dropout(p=self.dropout_weight)

        #  given a sentence, returns its unnormalized attention weight
        self.attention_ff = nn.Sequential(nn.Linear(cnn_output_size, d),
                                          nn.ReLU(), nn.Linear(d, 1))

        self.ff_before_alpha = nn.Sequential(
            nn.Linear(1, 50),
            nn.ReLU(),
            nn.Linear(50, 1),
        )

        ff_input_size = cnn_output_size
        if self.with_entity_embeddings:
            ff_input_size += embedding_size

        # output layer
        self.ff = nn.Sequential(nn.Linear(ff_input_size, d), nn.ReLU(),
                                nn.Linear(d, self.num_classes))

        self.loss = torch.nn.BCEWithLogitsLoss(
        )  # sigmoid + binary cross entropy
        self.metrics = {}
        self.metrics['ap'] = MultilabelAveragePrecision(
        )  # average precision = AUC
        self.metrics['bag_loss'] = Average()  # to display bag-level loss

        if self.sent_loss_weight > 0:
            self.metrics['sent_loss'] = Average(
            )  # to display sentence-level loss
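
The position embedder above reserves one row per clipped relative distance, hence `2 * max_distance + 1` rows. Below is a minimal sketch of how a signed token-to-entity distance could be mapped into that index range; the clip-and-shift convention and the concrete `max_distance` value are assumptions, since `RelationInstancesReader` is not shown here.

import torch

max_distance = 30  # assumed; the real value is RelationInstancesReader.max_distance
pos_embed = torch.nn.Embedding(2 * max_distance + 1, 5)

def distance_to_index(distance: int) -> int:
    # Clip to [-max_distance, max_distance], then shift into [0, 2 * max_distance].
    return max(-max_distance, min(max_distance, distance)) + max_distance

indices = torch.tensor([distance_to_index(d) for d in (-40, -3, 0, 7, 99)])
position_features = pos_embed(indices)  # shape: (5, 5)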
Code example #8
File: bert_qa.py  Project: pombredanne/UrcaNet
class BertQA(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sim_text_field_embedder: TextFieldEmbedder,
                 loss_weights: Dict,
                 sim_class_weights: List,
                 pretrained_sim_path: str = None,
                 use_scenario_encoding: bool = True,
                 sim_pretraining: bool = False,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BertQA, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        if use_scenario_encoding:
            self._sim_text_field_embedder = sim_text_field_embedder
        self.loss_weights = loss_weights
        self.sim_class_weights = sim_class_weights
        self.use_scenario_encoding = use_scenario_encoding
        self.sim_pretraining = sim_pretraining

        if self.sim_pretraining and not self.use_scenario_encoding:
            raise ValueError(
                "When pretraining Scenario Interpretation Module, you should use it."
            )

        embedding_dim = self._text_field_embedder.get_output_dim()
        self._action_predictor = torch.nn.Linear(embedding_dim, 4)
        self._sim_token_label_predictor = torch.nn.Linear(embedding_dim, 4)
        self._span_predictor = torch.nn.Linear(embedding_dim, 2)
        self._action_accuracy = CategoricalAccuracy()
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_loss_metric = Average()
        self._action_loss_metric = Average()
        self._sim_loss_metric = Average()
        self._sim_yes_f1 = F1Measure(2)
        self._sim_no_f1 = F1Measure(3)

        if use_scenario_encoding and pretrained_sim_path is not None:
            logger.info("Loading pretrained model..")
            self.load_state_dict(torch.load(pretrained_sim_path))
            for param in self._sim_text_field_embedder.parameters():
                param.requires_grad = False

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def get_passage_representation(self, bert_output, bert_input):
        # Shape: (batch_size, bert_input_len)
        input_type_ids = self.get_input_type_ids(
            bert_input['bert-type-ids'], bert_input['bert-offsets'],
            self._text_field_embedder._token_embedders['bert']).float()
        # Shape: (batch_size, bert_input_len)
        input_mask = util.get_text_field_mask(bert_input).float()
        passage_mask = input_mask - input_type_ids  # works only with one [SEP]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        passage_representation = bert_output * passage_mask.unsqueeze(2)
        # Shape: (batch_size, passage_len, embedding_dim)
        passage_representation = passage_representation[:, passage_mask.sum(dim=0) > 0, :]
        # Shape: (batch_size, passage_len)
        passage_mask = passage_mask[:, passage_mask.sum(dim=0) > 0]

        return passage_representation, passage_mask

    def forward(
            self,  # type: ignore
            bert_input: Dict[str, torch.LongTensor],
            sim_bert_input: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        if self.use_scenario_encoding:
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_bert_input_token_labels_wp = sim_bert_input[
                'scenario_gold_encoding']
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input)
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_input_mask_wp = (sim_bert_input['bert'] != 0).float()
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[
                'bert-type-ids'].float()  # works only with one [SEP]
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_bert_output_wp * sim_passage_mask_wp.unsqueeze(2)
            # Shape: (batch_size, passage_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_passage_representation_wp[:, sim_passage_mask_wp.sum(dim=0) > 0, :]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[:, sim_passage_mask_wp.sum(dim=0) > 0]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_mask_wp = sim_passage_mask_wp[:, sim_passage_mask_wp.sum(dim=0) > 0]

            # Shape: (batch_size, passage_len_wp, 4)
            sim_token_logits_wp = self._sim_token_label_predictor(
                sim_passage_representation_wp)

            if span_start is not None:  # during training and validation
                class_weights = torch.tensor(self.sim_class_weights,
                                             device=sim_token_logits_wp.device,
                                             dtype=torch.float)
                sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4),
                                         sim_passage_token_labels_wp.view(-1),
                                         ignore_index=0,
                                         weight=class_weights)
                self._sim_loss_metric(sim_loss.item())
                self._sim_yes_f1(sim_token_logits_wp,
                                 sim_passage_token_labels_wp,
                                 sim_passage_mask_wp)
                self._sim_no_f1(sim_token_logits_wp,
                                sim_passage_token_labels_wp,
                                sim_passage_mask_wp)
                if self.sim_pretraining:
                    return {'loss': sim_loss}

            if not self.sim_pretraining:
                # Shape: (batch_size, passage_len_wp)
                bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax(
                    dim=2)) * sim_passage_mask_wp.long()
                # Shape: (batch_size, bert_input_len_wp)
                bert_input_wp_len = bert_input['history_encoding'].size(1)
                if bert_input['scenario_encoding'].size(1) > bert_input_wp_len:
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = bert_input[
                        'scenario_encoding'][:, :bert_input_wp_len]
                else:
                    batch_size = bert_input['scenario_encoding'].size(0)
                    difference = bert_input_wp_len - bert_input[
                        'scenario_encoding'].size(1)
                    zeros = torch.zeros(
                        batch_size,
                        difference,
                        dtype=bert_input['scenario_encoding'].dtype,
                        device=bert_input['scenario_encoding'].device)
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = torch.cat(
                        [bert_input['scenario_encoding'], zeros], dim=1)

        # Shape: (batch_size, bert_input_len + 1, embedding_dim)
        bert_output = self._text_field_embedder(bert_input)
        # Shape: (batch_size, embedding_dim)
        pooled_output = bert_output[:, 0]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        bert_output = bert_output[:, 1:, :]
        # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len)
        passage_representation, passage_mask = self.get_passage_representation(
            bert_output, bert_input)

        # Shape: (batch_size, 4)
        action_logits = self._action_predictor(pooled_output)
        # Shape: (batch_size, passage_len, 2)
        span_logits = self._span_predictor(passage_representation)
        # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1)
        span_start_logits, span_end_logits = span_logits.split(1, dim=2)
        # Shape: (batch_size, passage_len)
        span_start_logits = span_start_logits.squeeze(2)
        # Shape: (batch_size, passage_len)
        span_end_logits = span_end_logits.squeeze(2)

        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "pooled_output": pooled_output,
            "passage_representation": passage_representation,
            "action_logits": action_logits,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if self.use_scenario_encoding:
            output_dict["sim_token_logits"] = sim_token_logits_wp

        # Compute the loss for training (and for validation)
        if span_start is not None:
            # Shape: (batch_size,)
            span_loss = nll_loss(util.masked_log_softmax(
                span_start_logits, passage_mask),
                                 span_start.squeeze(1),
                                 reduction='none')
            # Shape: (batch_size,)
            span_loss += nll_loss(util.masked_log_softmax(
                span_end_logits, passage_mask),
                                  span_end.squeeze(1),
                                  reduction='none')
            # Shape: (batch_size,)
            more_mask = (label == self.vocab.get_token_index(
                'More', namespace="labels")).float()
            # Shape: (batch_size,)
            span_loss = (span_loss * more_mask).sum() / (more_mask.sum() +
                                                         1e-6)
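            # Average the span loss only over examples whose gold action is 'More';
            # the 1e-6 term guards against division by zero when a batch has none.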
            if more_mask.sum() > 1e-7:
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(1), more_mask)
                self._span_end_accuracy(span_end_logits, span_end.squeeze(1),
                                        more_mask)
                # Shape: (batch_size, 2)
                span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long()
                self._span_accuracy(best_span,
                                    torch.cat([span_start, span_end], dim=1),
                                    span_acc_mask)

            action_loss = cross_entropy(action_logits, label)
            self._action_accuracy(action_logits, label)

            self._span_loss_metric(span_loss.item())
            self._action_loss_metric(action_loss.item())
            output_dict['loss'] = self.loss_weights[
                'span_loss'] * span_loss + self.loss_weights[
                    'action_loss'] * action_loss

        # Compute SQuAD-style EM and F1 on the predicted span string (validation and test only).
        if not self.training:  # true during validation and test
            output_dict['best_span_str'] = []
            batch_size = len(metadata)
            for i in range(batch_size):
                passage_text = metadata[i]['passage_text']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_str = passage_text[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_str)
                if 'gold_span' in metadata[i]:
                    if metadata[i]['action'] == 'More':
                        gold_span = metadata[i]['gold_span']
                        self._squad_metrics(best_span_str, [gold_span])
        return output_dict

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        action_probs = softmax(output_dict['action_logits'], dim=1)
        output_dict['action_probs'] = action_probs

        predictions = action_probs.cpu().data.numpy()
        argmax_indices = numpy.argmax(predictions, axis=1)
        labels = [
            self.vocab.get_token_from_index(x, namespace="labels")
            for x in argmax_indices
        ]
        output_dict['label'] = labels
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        if self.use_scenario_encoding:
            sim_loss = self._sim_loss_metric.get_metric(reset)
            _, _, yes_f1 = self._sim_yes_f1.get_metric(reset)
            _, _, no_f1 = self._sim_no_f1.get_metric(reset)

        if self.sim_pretraining:
            return {'sim_macro_f1': (yes_f1 + no_f1) / 2}

        try:
            action_acc = self._action_accuracy.get_metric(reset)
        except ZeroDivisionError:
            action_acc = 0
        try:
            start_acc = self._span_start_accuracy.get_metric(reset)
        except ZeroDivisionError:
            start_acc = 0
        try:
            end_acc = self._span_end_accuracy.get_metric(reset)
        except ZeroDivisionError:
            end_acc = 0
        try:
            span_acc = self._span_accuracy.get_metric(reset)
        except ZeroDivisionError:
            span_acc = 0

        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        span_loss = self._span_loss_metric.get_metric(reset)
        action_loss = self._action_loss_metric.get_metric(reset)
        agg_metric = span_acc + action_acc * 0.45

        metrics = {
            'action_acc': action_acc,
            'span_acc': span_acc,
            'span_loss': span_loss,
            'action_loss': action_loss,
            'agg_metric': agg_metric
        }

        if self.use_scenario_encoding:
            metrics['sim_macro_f1'] = (yes_f1 + no_f1) / 2

        if not self.training:  # during validation
            metrics['em'] = exact_match
            metrics['f1'] = f1_score

        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
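        # Illustration (ours, not part of the original): with passage_length = 3, a
        # flattened argmax of 5 recovers start = 5 // 3 = 1 and end = 5 % 3 = 2,
        # i.e. the span (1, 2) in the 3 x 3 start/end score matrix.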

    def get_input_type_ids(self, type_ids, offsets, embedder):
        "Converts (bsz, seq_len_wp) to (bsz, seq_len_wp) by indexing."
        batch_size = type_ids.size(0)
        full_seq_len = type_ids.size(1)
        if full_seq_len > embedder.max_pieces:  # Recombine if we had used sliding window approach
            assert batch_size == 1 and type_ids.max() > 0
            num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero(
            ).size(0)
            select_indices = embedder.indices_to_select(
                full_seq_len, num_question_tokens)
            type_ids = type_ids[:, select_indices]

        range_vector = util.get_range_vector(
            batch_size, device=util.get_device_of(type_ids)).unsqueeze(1)
        type_ids = type_ids[range_vector, offsets]
        return type_ids
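
The span selection above relies on AllenNLP utilities plus the upper-triangular trick from `get_best_span`. Below is a minimal, self-contained sketch of the same logic in plain PyTorch; the helper name `best_valid_span` and the toy tensors are ours, not part of the example, while the shapes and the -1e7 sentinel mirror the code above.

import torch


def best_valid_span(start_logits: torch.Tensor, end_logits: torch.Tensor,
                    passage_mask: torch.Tensor) -> torch.Tensor:
    """Pick the highest-scoring span with start <= end, ignoring masked positions."""
    # Push masked (padding) positions to a large negative value before combining scores.
    start_logits = start_logits.masked_fill(passage_mask == 0, -1e7)
    end_logits = end_logits.masked_fill(passage_mask == 0, -1e7)
    batch_size, passage_length = start_logits.size()
    # (batch_size, passage_length, passage_length): score of every (start, end) pair.
    span_scores = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)
    # Keep only the upper triangle, where the span ends at or after it starts.
    valid = torch.triu(torch.ones(passage_length, passage_length,
                                  device=start_logits.device)).bool()
    span_scores = span_scores.masked_fill(~valid, -1e7)
    # Flatten, take the argmax, and recover (start, end) with integer division / modulo.
    flat_best = span_scores.view(batch_size, -1).argmax(-1)
    return torch.stack([flat_best // passage_length, flat_best % passage_length], dim=-1)


# One passage of four tokens, the last of which is padding.
starts = torch.tensor([[0.1, 2.0, 0.3, 5.0]])
ends = torch.tensor([[0.2, 0.1, 3.0, 9.0]])
mask = torch.tensor([[1, 1, 1, 0]])
print(best_valid_span(starts, ends, mask))  # tensor([[1, 2]])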
Code example #10
class AttnSupSeq2Seq(Model):
    """
    Adaptation of the ``SimpleSeq2Seq`` class in allennlp_models, with auxiliary attention-supervision loss

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention,
                 schema_path: str = None,
                 missing_alignment_int: int = 0,
                 indexfield_padding_index: int = -1,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.0,
                 dec_dropout: float = 0.0,
                 attn_loss_lambda: float = 0.5,
                 token_based_metric: Metric = None) -> None:
        super(AttnSupSeq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._indexfield_padding_index = indexfield_padding_index
        self._missing_alignment_int = missing_alignment_int

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None

        if token_based_metric:
            self._token_based_metric = token_based_metric
        else:
            self._token_based_metric = TokenSequenceAccuracy()
        # log attention supervision CE loss as a metric
        self._attn_sup_loss = Average()
        self._sql_metrics = schema_path is not None
        if self._sql_metrics:
            # SQL specific metrics: match between the templates free of schema constants,
            # and match between the schema constants
            self._schema_free_match = GlobalTemplAccuracy(
                schema_path=schema_path)
            self._kb_match = KnowledgeBaseConstsAccuracy(
                schema_path=schema_path)

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)
        self._attn_loss_lambda = attn_loss_lambda
        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention
        self._attention._normalize = False
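        # Normalization is disabled so that the attention module returns raw logits; they
        # are needed for the attention-supervision loss, and a masked softmax is applied
        # manually in `_prepare_attended_input`.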

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim(
        )
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        # A weighted average over encoder outputs will be concatenated to the previous target embedding
        # to form the input to the decoder at each time step.
        self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        _, output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward_on_instances(
            self, instances: List[Instance]) -> List[Dict[str, numpy.ndarray]]:
        """
        Takes a list of  :class:`~allennlp.data.instance.Instance`s, converts that text into
        arrays using this model's :class:`Vocabulary`, passes those arrays through
        :func:`self.forward()` and :func:`self.decode()` (which by default does nothing)
        and returns the result.  Before returning the result, we convert any
        ``torch.Tensors`` into numpy arrays and separate the
        batched output into a list of individual dicts per instance. Note that typically
        this will be faster on a GPU (and conditionally, on a CPU) than repeated calls to
        :func:`forward_on_instance`.

        Parameters
        ----------
        instances : List[Instance], required
            The instances to run the model on.

        Returns
        -------
        A list of the models output for each instance.
        """
        batch_size = len(instances)
        with torch.no_grad():
            cuda_device = self._get_prediction_device()
            dataset = Batch(instances)
            dataset.index_instances(self.vocab)
            model_input = util.move_to_device(dataset.as_tensor_dict(),
                                              cuda_device)
            outputs = self.decode(self(**model_input))

            instance_separated_output: List[Dict[str, numpy.ndarray]] = [
                {} for _ in dataset.instances
            ]
            for name, output in list(outputs.items()):
                if isinstance(output, torch.Tensor):
                    # NOTE(markn): This is a hack because 0-dim pytorch tensors are not iterable.
                    # This occurs with batch size 1, because we still want to include the loss in that case.
                    if output.dim() == 0:
                        output = output.unsqueeze(0)

                    if output.size(0) != batch_size:
                        self._maybe_warn_for_unseparable_batches(name)
                        continue
                    output = output.detach().cpu().numpy()
                elif len(output) != batch_size:
                    self._maybe_warn_for_unseparable_batches(name)
                    continue
                for instance_output, batch_element in zip(
                        instance_separated_output, output):
                    instance_output[name] = batch_element

            for instance_output, instance_input in zip(
                    instance_separated_output, instances):
                for field in instance_input.fields:
                    try:
                        instance_output[field] = instance_input.fields[
                            field].tokens
                    except Exception:
                        # Fields without a `.tokens` attribute are skipped.
                        continue

            return instance_separated_output

    @overrides
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            alignment_sequence: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        alignment_sequence : ``torch.Tensor``, optional (default = None)
            Gold alignment indices into the source sequence for each target step, used for
            the auxiliary attention-supervision loss.
        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)

            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            alignment_sequence = alignment_sequence.squeeze(-1)

            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens,
                                             alignment_sequence)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                if self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens["tokens"])

                predicted_tokens = self.decode(output_dict)["predicted_tokens"]
                target_tokens_str = self.decode_target_tokens(target_tokens)

                if self._token_based_metric:
                    self._token_based_metric(predicted_tokens,
                                             target_tokens_str)
                if self._sql_metrics:
                    self._kb_match(predicted_tokens, target_tokens_str)
                    self._schema_free_match(predicted_tokens,
                                            target_tokens_str)

        # If the attention module implements a coverage mechanism, reset its coverage
        # vector after every batch.
        try:
            self._attention.reset_coverage_vector()
        except Exception:
            pass

        return output_dict

    def decode_target_tokens(self, target_tokens):
        target_indices = target_tokens['tokens'].detach().cpu().numpy()
        target_tokens_output = []
        for i in range(target_indices.shape[0]):
            cur_target_indices = target_indices[i]
            cur_target_indices = list(cur_target_indices)
            if self._end_index in cur_target_indices:
                cur_target_indices = cur_target_indices[:cur_target_indices.
                                                        index(self._end_index)]
            if self._start_index in cur_target_indices:
                cur_target_indices = cur_target_indices[
                    cur_target_indices.index(self._start_index) + 1:]
            target_tokens_str = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace)
                for x in cur_target_indices
            ]
            target_tokens_output.append(target_tokens_str)

        return target_tokens_output

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        encoder_outputs = self._emb_dropout(encoder_outputs)
        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            alignment_sequence: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        step_attn_weights: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: with probability `_scheduled_sampling_ratio` during
                # training, feed back the model's own prediction from the previous step
                # instead of the gold token.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            # shape: (batch_size, input_max_size)
            input_weights, output_projections, state = self._prepare_output_projections(
                input_choices, state)

            step_attn_weights.append(input_weights.unsqueeze(1))

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        # shape: (batch_size, num_decoding_steps - 1, max_input_sequence_length)
        attention_input_weights = torch.cat(step_attn_weights[:-1], 1)

        output_dict = {
            "predictions": predictions,
            'attention_input_weights': attention_input_weights
        }

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # shape: (batch_size, num_decoding_steps, max_input_sequence_length)
            alignment_mask = self._get_alignment_mask(alignment_sequence)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)

            attn_sup_loss = self._get_attn_sup_loss(attention_input_weights,
                                                    alignment_mask,
                                                    alignment_sequence)
            self._attn_sup_loss(attn_sup_loss.detach().cpu().item())

            output_dict["loss"] = loss + self._attn_loss_lambda * attn_sup_loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # shape: (group_size, encoder_output_dim)
        attended_input, input_weights = self._prepare_attended_input(
            decoder_hidden, encoder_outputs, source_mask)

        # shape: (group_size, decoder_output_dim + target_embedding_dim)
        decoder_input = torch.cat((attended_input, embedded_input), -1)
        decoder_input = self._dec_dropout(decoder_input)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(
            self._dec_dropout(decoder_hidden))

        return input_weights, output_projections, state

    def _prepare_attended_input(
        self,
        decoder_hidden_state: torch.LongTensor = None,
        encoder_outputs: torch.LongTensor = None,
        encoder_outputs_mask: torch.LongTensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs_mask = encoder_outputs_mask.float()
        # shape: (batch_size, max_input_sequence_length)
        input_logits = self._attention(decoder_hidden_state, encoder_outputs,
                                       encoder_outputs_mask)
        # The attention module returns unnormalized logits, which are needed for the
        # attention-supervision loss; we normalize them here only to compute the weighted
        # sum and return the raw logits to the caller.
        input_weights = masked_softmax(input_logits, encoder_outputs_mask)
        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input, input_logits

    @staticmethod
    def _get_attn_sup_loss(attn_weights: torch.Tensor,
                           alignment_mask: torch.Tensor,
                           alignment_sequence: torch.Tensor) -> torch.Tensor:
        """
        Compute the attention supervision cross-entropy loss: at each decoding step, the
        attention logits over the source tokens are scored against the index of the
        aligned source token, with steps that have no alignment masked out.
        """
        # shape: (batch_size, max_decoding_steps, max_input_seq_length)
        attn_weights = attn_weights.float()

        # Map the padding index (-1) to a valid class index; these positions carry zero
        # weight in `alignment_mask`, so the substitution does not affect the loss.
        alignment_sequence[alignment_sequence == -1] = 0
        # For each attn_weights[batch_index, step_index, :], the target class is
        # alignment_sequence[batch_index, step_index].
        return util.sequence_cross_entropy_with_logits(attn_weights,
                                                       alignment_sequence,
                                                       alignment_mask)

    def _get_alignment_mask(self, alignment_sequence):
        """
        The alignment mask combines the padding mask with a mask over steps that have no
        alignment.
        shape: (batch_size, max_steps, max_input)
        """
        pad_mask = alignment_sequence != self._indexfield_padding_index
        missing_mask = alignment_sequence != self._missing_alignment_int

        return pad_mask * missing_mask
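        # E.g. with the default padding index -1 and missing-alignment index 0, an
        # alignment row [3, 0, -1] keeps only the first position.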

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._bleu:
                all_metrics.update(self._bleu.get_metric(reset=reset))
            all_metrics.update(
                self._token_based_metric.get_metric(reset=reset))
            if self._sql_metrics:
                all_metrics.update(self._kb_match.get_metric(reset=reset))
                all_metrics.update(
                    self._schema_free_match.get_metric(reset=reset))
            all_metrics['attn_sup_loss'] = self._attn_sup_loss.get_metric(
                reset=reset)
        return all_metrics
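
As a small, self-contained illustration of the shift-by-one target alignment and masking described in the `_get_loss` docstring (plain PyTorch; `shifted_masked_ce` is our name and a simplification of `util.sequence_cross_entropy_with_logits`, which additionally averages per batch element):

import torch
import torch.nn.functional as F


def shifted_masked_ce(logits: torch.Tensor, targets: torch.Tensor,
                      target_mask: torch.Tensor) -> torch.Tensor:
    """Cross entropy where logits[:, t] is scored against targets[:, t + 1], ignoring padding."""
    # The decoder never predicts the <S> position, so targets and mask are shifted left by
    # one to line up with the logits (which cover len(targets) - 1 steps).
    relevant_targets = targets[:, 1:]            # (batch, num_decoding_steps)
    relevant_mask = target_mask[:, 1:].float()   # (batch, num_decoding_steps)
    # Per-token negative log likelihood of the gold target.
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(2, relevant_targets.unsqueeze(2)).squeeze(2)
    # Average over unmasked (non-padding) tokens only.
    return (nll * relevant_mask).sum() / relevant_mask.sum().clamp(min=1.0)


# Target sequence <S> w1 w2 <E> <P> with vocabulary {0: <P>, 1: <S>, 2: <E>, 3: w1, 4: w2}.
targets = torch.tensor([[1, 3, 4, 2, 0]])
mask = torch.tensor([[1, 1, 1, 1, 0]])
logits = torch.randn(1, 4, 5)  # one logit vector per decoding step
print(shifted_masked_ce(logits, targets, mask))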