Example #1
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder = None,
                 dropout: float = 0.0,
                 num_samples: int = None,
                 sparse_embeddings: bool = False,
                 bidirectional: bool = False,
                 initializer=InitializerApplicator(),
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._contextualizer = contextualizer
        self._bidirectional = bidirectional

        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        self._softmax_loss = SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                         embedding_dim=self._forward_dim)

        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)
Example #2
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        language_model_head: LanguageModelHead,
        contextualizer: Seq2SeqEncoder = None,
        target_namespace: str = "bert",
        dropout: float = 0.0,
        initializer: InitializerApplicator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._contextualizer = contextualizer
        if contextualizer:
            check_dimensions_match(
                text_field_embedder.get_output_dim(),
                contextualizer.get_input_dim(),
                "text field embedder output",
                "contextualizer input",
            )
        self._language_model_head = language_model_head
        self._target_namespace = target_namespace
        self._perplexity = Perplexity()
        self._dropout = torch.nn.Dropout(dropout)

        if initializer is not None:
            initializer(self)
Example #3
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        dropout: float = None,
        num_samples: int = None,
        sparse_embeddings: bool = False,
        bidirectional: bool = False,
        initializer: InitializerApplicator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer = contextualizer
        self._bidirectional = bidirectional

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size("transactions"),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
        else:
            self._softmax_loss = SoftmaxLoss(
                num_words=vocab.get_vocab_size("transactions"),
                embedding_dim=self._forward_dim,
            )

        # This buffer is now unused and exists only for backwards compatibility reasons.
        self.register_buffer("_last_average_loss", torch.zeros(1))

        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)
Example #4
    def __init__(
        self,
        vocab: Vocabulary,
        encoder=None,
        source_encoder=None,
        trainable: bool = True,
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)
        self._bleu = BLEU()
        self._perplexity = Perplexity()
Example #5
    def __init__(self, vocab: Vocabulary, embedding_dim: int):
        super().__init__(vocab)
        self.vocab_size = vocab.get_vocab_size("tokens")
        self.pad_token_index = vocab.get_token_index("[PAD]")

        self.prediction_head = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim),
            Activation.by_name('gelu')(),
            torch.nn.LayerNorm(embedding_dim, eps=1e-12),
            torch.nn.Linear(embedding_dim, self.vocab_size))

        self._accuracy = CategoricalAccuracy()
        self._perplexity = Perplexity()
Example #6
class ContextualSeq2seq(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        encoder=None,
        source_encoder=None,
        trainable: bool = True,
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)
        self._bleu = BLEU()
        self._perplexity = Perplexity()

    def forward(
        self,
        source_context_tokens: Dict[str, torch.LongTensor],
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor],
    ) -> Dict[str, torch.Tensor]:
        pass

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        pass

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {
            "bleu": self._bleu.get_metric(reset),
            "perplexity": self._perplexity.get_metric(reset),
        }
        return metrics
Example #7
    def __init__(
        self,
        vocab: Vocabulary,
        sequence_field_embedder: TextFieldEmbedder,
        structure_field_embedder: TextFieldEmbedder,
        seq2seq_encoder: Seq2SeqEncoder,
        tokens_masker: Optional[TokensMasker] = None,
    ) -> None:
        super().__init__(vocab)
        self._sequence_field_embedder = sequence_field_embedder
        self._structure_field_embedder = structure_field_embedder
        self._seq2seq_encoder = seq2seq_encoder
        self._head = LinearLanguageModelHead(
            vocab=vocab,
            input_dim=self._seq2seq_encoder.get_output_dim(),
            vocab_namespace="sequence")
        self._tokens_masker = tokens_masker

        ignore_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN)
        self._loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        self._perplexity = Perplexity()
Example #8
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        language_model_head: LanguageModelHead,
        contextualizer: Seq2SeqEncoder = None,
        target_namespace: str = "bert",
        dropout: float = 0.0,
        initializer: InitializerApplicator = None,
        n_best: int = 5,
        beam_search_generator: BeamSearchGenerator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._contextualizer = contextualizer
        if contextualizer:
            check_dimensions_match(
                text_field_embedder.get_output_dim(),
                contextualizer.get_input_dim(),
                "text field embedder output",
                "contextualizer input",
            )
        self._language_model_head = language_model_head
        self._target_namespace = target_namespace
        self._perplexity = Perplexity()
        self._dropout = torch.nn.Dropout(dropout)
        self._n_best = n_best
        self._beam_search_generator = beam_search_generator

        # Ensure beam_search_generator is compatible with text_field_embedder.
        if self._beam_search_generator is not None:
            self._beam_search_generator.validate_text_field_embedder(
                self._text_field_embedder)

        if initializer is not None:
            initializer(self)
Example #9
    def __init__(
        self,
        task: str,
        vocab: Vocabulary,
        input_dim: int,
        pretrained_model: str,
        loss_weight: float = 1.0,
        metric: str = 'ppl',
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        # Isn't this costly? Why not reuse our encoder?
        mlm = AutoModelForMaskedLM.from_pretrained(pretrained_model)

        # R: This is somewhat (or very) ugly code; it is not obvious how to do
        # this more cleanly while supporting so many *ForMaskedLM models.
        #
        # P.S. I guess DistilBERT is missing; it was not clear which attribute to pick.
        self.lm_config = mlm.config
        # Different *ForMaskedLM architectures expose their LM head under
        # different attribute names, so probe the known candidates in order.
        for head_attr in ("pred_layer", "cls", "lm_head", "generator_lm_head", "predictions"):
            if hasattr(mlm, head_attr):
                self.mlm = getattr(mlm, head_attr)
                break
        else:
            logger.error(pretrained_model + ' not yet configured for masked language modeling')
            exit(1)

        self.task = task
        self.input_dim = input_dim
        self.loss_weight = loss_weight
        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        
        self.metrics = {
            "ppl": Perplexity(),
        }
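
The relationship between the cross-entropy loss above and the reported "ppl" metric is worth spelling out: perplexity is the exponential of the average per-token negative log-likelihood, and positions labelled -100 are excluded from that average. Below is a minimal sketch in plain PyTorch; the tensor shapes and the manual exp are illustrative assumptions, not part of the example above, but the arithmetic is essentially what the Perplexity() metric reports from the recorded losses.

import math
import torch

# Toy masked-LM logits and labels: batch of 2 sequences, 4 positions, vocab of 10.
logits = torch.randn(2, 4, 10)
labels = torch.randint(0, 10, (2, 4))
labels[:, 0] = -100  # positions marked -100 are ignored, e.g. unmasked tokens

# CrossEntropyLoss with ignore_index=-100 averages only over the labelled positions.
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.view(-1, 10), labels.view(-1))

# Perplexity is the exponential of that average negative log-likelihood.
ppl = math.exp(loss.item())
print(f"loss={loss.item():.4f}  perplexity={ppl:.4f}")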
Example #10
    def __init__(self, backbone: ModelBackbone, dropout: float = None) -> None:
        super(LanguageModelling, self).__init__(backbone)

        if not backbone.featurizer.has_word_features:
            raise ConfigurationError(
                "`LanguageModelling` defines a word-level next token language model. "
                "Please check your `features` configuration to enable at least `words` features."
            )

        self._forward_dim = backbone.encoder.get_output_dim()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self.metrics = {"perplexity": Perplexity()}

        self._loss = SoftmaxLoss(
            num_words=vocabulary.words_vocab_size(self.backbone.vocab),
            embedding_dim=self.backbone.encoder.get_output_dim(),
        )
Example #11
    def __init__(
        self,
        backbone: ModelBackbone,
        dropout: float = None,
        bidirectional: bool = False,
    ) -> None:
        super(LanguageModelling, self).__init__(backbone)

        self.bidirectional = bidirectional

        if not backbone.featurizer.has_word_features:
            raise ConfigurationError(
                "`LanguageModelling` defines a word-level next token language model. "
                "Please check your `features` configuration to enable at least `words` features."
            )

        if backbone.encoder.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {backbone.encoder.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        if self.bidirectional:
            self._forward_dim = backbone.encoder.get_output_dim() // 2
        else:
            self._forward_dim = backbone.encoder.get_output_dim()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self.metrics = {"perplexity": Perplexity()}

        self._loss = SoftmaxLoss(
            num_words=vocabulary.words_vocab_size(self.backbone.vocab),
            embedding_dim=self._forward_dim,
        )
Example #12
class NextTokenLM(Model):
    """
    The `NextTokenLM` embeds some input tokens, contextualizes them, then predicts the next word,
    computing a loss against known target.

    If `BeamSearch` is given, this model will predict a sequence of next tokens.

    !!! NOTE
        This was developed for use in a demo, not for training.  You *definitely* don't want to
        train a language model using this code; it would be incredibly inefficient. It does,
        however, compute correct gradients of the loss, so you can use it for interesting visualization
        of the gradients of a pretrained model, and it appears to be fast enough to sample from, at
        least for one word at a time.

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the indexed tokens we get in `forward`.
    language_model_head : `LanguageModelHead`
        The `torch.nn.Module` that goes from the hidden states output by the contextualizer to
        logits over some output vocabulary.
    contextualizer : `Seq2SeqEncoder`, optional (default=`None`)
        Used to "contextualize" the embeddings.  This is optional because the contextualization
        might actually be done in the text field embedder.
    target_namespace : `str`, optional (default=`'bert'`)
        Namespace to use to convert predicted token ids to strings in
        `Model.make_output_human_readable`.
    dropout : `float`, optional (default=`0.0`)
        If specified, dropout is applied to the contextualized embeddings before computation of
        the softmax. The contextualized embeddings themselves are returned without dropout.
    n_best : `int`, optional (default = `5`)
        The number of best tokens to predict. If `beam_search` is given, this option is ignored.
    beam_search_generator : `BeamSearchGenerator`, optional (default = `None`)
        An optional `BeamSearchGenerator`. If given, the model will predict sequences of next
        tokens instead of just a single next token.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        language_model_head: LanguageModelHead,
        contextualizer: Seq2SeqEncoder = None,
        target_namespace: str = "bert",
        dropout: float = 0.0,
        initializer: InitializerApplicator = None,
        n_best: int = 5,
        beam_search_generator: BeamSearchGenerator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._contextualizer = contextualizer
        if contextualizer:
            check_dimensions_match(
                text_field_embedder.get_output_dim(),
                contextualizer.get_input_dim(),
                "text field embedder output",
                "contextualizer input",
            )
        self._language_model_head = language_model_head
        self._target_namespace = target_namespace
        self._perplexity = Perplexity()
        self._dropout = torch.nn.Dropout(dropout)
        self._n_best = n_best
        self._beam_search_generator = beam_search_generator

        # Ensure beam_search_generator is compatible with text_field_embedder.
        if self._beam_search_generator is not None:
            self._beam_search_generator.validate_text_field_embedder(self._text_field_embedder)

        if initializer is not None:
            initializer(self)

    def forward(  # type: ignore
        self, tokens: TextFieldTensors, target_ids: TextFieldTensors = None
    ) -> Dict[str, torch.Tensor]:
        """
        Run a forward pass of the model, returning an output tensor dictionary with
        the following fields:

        - `"probabilities"`: a tensor of shape `(batch_size, n_best)` representing
          the probabilities of the predicted tokens, where `n_best`
          is either `self._n_best` or `beam_size` if using beam search.
        - `"top_indices"`: a tensor of shape `(batch_size, n_best, num_predicted_tokens)`
          containing the IDs of the predicted tokens, where `num_predicted_tokens` is just
          1 unless using beam search, in which case it depends on the parameters of the beam search.
        - `"token_ids"`: a tensor of shape `(batch_size, num_input_tokens)` containing the IDs
          of the input tokens.
        - `"loss"` (optional): the loss of the batch, only given if `target_ids` is not `None`.

        """
        output_dict = {
            "token_ids": util.get_token_ids_from_text_field_tensors(tokens),
        }

        # Shape: (batch_size, vocab_size)
        target_logits = self._next_token_scores(tokens)

        # Compute loss.
        if target_ids is not None:
            batch_size, vocab_size = target_logits.size()
            tmp = util.get_token_ids_from_text_field_tensors(target_ids)
            # In some scenarios, target_ids might be a topk list of token ids (e.g. sorted by probabilities).
            # Therefore, we need to make sure we keep only one target token per instance.
            # Assume the first token for each instance is the most desirable one (e.g. the highest-probability one).
            tmp = tmp[:, 0] if len(tmp.shape) == 2 else tmp
            assert len(tmp.shape) <= 2
            targets = tmp.view(batch_size)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        if self._beam_search_generator is not None:
            # Dummy start predictions.
            # Shape: (batch_size,)
            start_predictions = torch.zeros(
                target_logits.size()[0], device=target_logits.device, dtype=torch.int
            )

            state = self._beam_search_generator.get_step_state(tokens)

            # Put this in here to avoid having to re-compute on the first step of beam search.
            state["start_target_logits"] = target_logits

            # Shape (top_indices): (batch_size, beam_size, num_predicted_tokens)
            # Shape (top_log_probs): (batch_size, beam_size)
            top_indices, top_log_probs = self._beam_search_generator.search(
                start_predictions, state, self._beam_search_step
            )

            # Shape: (batch_size, beam_size)
            top_probs = top_log_probs.exp()
        else:
            # Shape: (batch_size, vocab_size)
            probs = torch.nn.functional.softmax(target_logits, dim=-1)

            # Shape (both): (batch_size, n_best)
            # min here largely because tests use small vocab
            top_probs, top_indices = probs.topk(k=min(target_logits.size(-1), self._n_best), dim=-1)

            # Shape: (batch_size, n_best, 1)
            top_indices = top_indices.unsqueeze(-1)

        output_dict["top_indices"] = top_indices
        output_dict["probabilities"] = top_probs

        return output_dict

    def _next_token_scores(self, tokens: TextFieldTensors) -> torch.Tensor:
        """
        Get the unnormalized log probabilities of the potential next token.
        """
        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            # The mask is computed from the token tensors, not from the embedded output.
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
            final_embeddings = util.get_final_encoder_states(contextual_embeddings, mask)
        else:
            final_embeddings = embeddings[:, -1]

        # Shape: (batch_size, vocab_size)
        return self._language_model_head(self._dropout(final_embeddings))

    def _beam_search_step(
        self, predicted_tokens: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Step function to use with `BeamSearch`.

        `predicted_tokens` is a tensor of shape `(group_size,)` and
        `state` is a dictionary of tensors with the following fields:
        - "token_ids": shape `(group_size, num_tokens)`
        - "mask": shape `(group_size, num_tokens)`
        - "type_ids": shape `(group_size, num_tokens)`
        """
        assert self._beam_search_generator is not None

        if step == 0:
            # Shape: (group_size, vocab_size)
            start_target_logits = state.pop("start_target_logits")

            # Shape: (group_size, vocab_size)
            start_target_log_probs = torch.nn.functional.log_softmax(start_target_logits, dim=-1)

            return start_target_log_probs, state

        inputs = self._beam_search_generator.prepare_step_input(predicted_tokens, state)
        state = self._beam_search_generator.get_step_state(inputs)

        # Shape: (group_size, vocab_size)
        next_token_scores = self._next_token_scores(inputs)

        # Shape: (group_size, vocab_size)
        log_probs = torch.nn.functional.log_softmax(next_token_scores, dim=-1)

        return log_probs, state

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}

    @overrides
    def make_output_human_readable(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Collects token strings from indices, adding two fields to the `output_dict`:

        - `"top_tokens"`: a list (for each instance in the batch) of lists (for each of
          the `n` best predictions) of lists of strings (for each token in each prediction).
        - `"tokens"`: a list of list (for each instance in the batch) of strings (for each
          input token in the instance).
        """
        # Gather predicted words.
        top_tokens = []
        # shape (output_dict["top_indices"]): (batch_size, n_best, num_predicted_tokens)
        for instance in output_dict["top_indices"]:
            # shape (instance): (n_best, num_predicted_tokens)
            instance_top_words = []
            for indices in instance:
                # shape (indices): (num_predicted_tokens,)
                instance_top_words.append(
                    [
                        self.vocab.get_token_from_index(
                            index.item(), namespace=self._target_namespace
                        )
                        for index in indices
                    ]
                )
            top_tokens.append(instance_top_words)

        # Gather input tokens.
        tokens = []
        for instance_tokens in output_dict["token_ids"]:
            tokens.append(
                [
                    self.vocab.get_token_from_index(
                        token_id.item(), namespace=self._target_namespace
                    )
                    for token_id in instance_tokens
                ]
            )

        output_dict["top_tokens"] = top_tokens  # type: ignore
        output_dict["tokens"] = tokens  # type: ignore
        return output_dict

    default_predictor = "next_token_lm"
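
The `_beam_search_step` method above follows a simple contract: it receives the tokens predicted at the previous step plus a state dictionary, and returns per-token log-probabilities together with the (possibly updated) state. Below is a minimal sketch of that contract in plain PyTorch, driven by a toy greedy loop rather than AllenNLP's `BeamSearch`; the step function, state keys, and loop here are illustrative stand-ins, not the library API.

from typing import Dict, Tuple
import torch

def toy_step(
    last_predictions: torch.Tensor,      # shape: (group_size,)
    state: Dict[str, torch.Tensor],
    step: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # On the first step, reuse precomputed logits instead of re-encoding.
    if step == 0:
        logits = state.pop("start_target_logits")
    else:
        # Stand-in for re-running the model on the grown prefixes.
        logits = torch.randn(last_predictions.size(0), state["vocab_size"].item())
    return torch.nn.functional.log_softmax(logits, dim=-1), state

# Toy greedy decode for 3 steps over a batch of 2 and a vocab of 7.
state = {"start_target_logits": torch.randn(2, 7), "vocab_size": torch.tensor(7)}
predictions = torch.zeros(2, dtype=torch.long)
for step in range(3):
    log_probs, state = toy_step(predictions, state, step)
    predictions = log_probs.argmax(dim=-1)
    print(step, predictions.tolist())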
Example #13
class MaskedLanguageModel(Model):
    """
    The `MaskedLanguageModel` embeds some input tokens (including some which are masked),
    contextualizes them, then predicts targets for the masked tokens, computing a loss against
    known targets.

    NOTE: This was developed for use in a demo, not for training.  It's possible that it will still
    work for training a masked LM, but it is very likely that some other code would be much more
    efficient for that.  This `does` compute correct gradients of the loss, because we use that in
    our demo, so in principle it should be able to train a model; we just don't necessarily endorse
    that use.

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the indexed tokens we get in `forward`.
    language_model_head : `LanguageModelHead`
        The `torch.nn.Module` that goes from the hidden states output by the contextualizer to
        logits over some output vocabulary.
    contextualizer : `Seq2SeqEncoder`, optional (default=`None`)
        Used to "contextualize" the embeddings.  This is optional because the contextualization
        might actually be done in the text field embedder.
    target_namespace : `str`, optional (default=`'bert'`)
        Namespace to use to convert predicted token ids to strings in
        `Model.make_output_human_readable`.
    dropout : `float`, optional (default=`0.0`)
        If specified, dropout is applied to the contextualized embeddings before computation of
        the softmax. The contextualized embeddings themselves are returned without dropout.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        language_model_head: LanguageModelHead,
        contextualizer: Seq2SeqEncoder = None,
        target_namespace: str = "bert",
        dropout: float = 0.0,
        initializer: InitializerApplicator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._contextualizer = contextualizer
        if contextualizer:
            check_dimensions_match(
                text_field_embedder.get_output_dim(),
                contextualizer.get_input_dim(),
                "text field embedder output",
                "contextualizer input",
            )
        self._language_model_head = language_model_head
        self._target_namespace = target_namespace
        self._perplexity = Perplexity()
        self._dropout = torch.nn.Dropout(dropout)

        if initializer is not None:
            initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        mask_positions: torch.BoolTensor,
        target_ids: TextFieldTensors = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            The output of `TextField.as_tensor()` for a batch of sentences.
        mask_positions : `torch.LongTensor`
            The positions in `tokens` that correspond to [MASK] tokens that we should try to fill
            in.  Shape should be (batch_size, num_masks).
        target_ids : `TextFieldTensors`
            This is a list of token ids that correspond to the mask positions we're trying to fill.
            It is the output of a `TextField`, purely for convenience, so we can handle wordpiece
            tokenizers and such without having to do crazy things in the dataset reader.  We assume
            that there is exactly one entry in the dictionary, and that it has a shape identical to
            `mask_positions` - one target token per mask position.
        """

        targets = None
        if target_ids is not None:
            targets = util.get_token_ids_from_text_field_tensors(target_ids)
        mask_positions = mask_positions.squeeze(-1)
        batch_size, num_masks = mask_positions.size()
        if targets is not None and targets.size() != mask_positions.size():
            raise ValueError(
                f"Number of targets ({targets.size()}) and number of masks "
                f"({mask_positions.size()}) are not equal")

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            # The mask is computed from the token tensors, not from the embedded output.
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
        else:
            contextual_embeddings = embeddings

        # Does advanced indexing to get the embeddings of just the mask positions, which is what
        # we're trying to predict.
        batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
        mask_embeddings = contextual_embeddings[batch_index, mask_positions]

        target_logits = self._language_model_head(
            self._dropout(mask_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size,
                5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)

        if targets is not None:
            target_logits = target_logits.view(batch_size * num_masks,
                                               vocab_size)
            targets = targets.view(batch_size * num_masks)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}

    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        top_words = []
        for instance_indices in output_dict["top_indices"]:
            top_words.append([[
                self.vocab.get_token_from_index(
                    index.item(), namespace=self._target_namespace)
                for index in mask_positions
            ] for mask_positions in instance_indices])
        output_dict["words"] = top_words
        tokens = []
        for instance_tokens in output_dict["token_ids"]:
            tokens.append([
                self.vocab.get_token_from_index(
                    token_id.item(), namespace=self._target_namespace)
                for token_id in instance_tokens
            ])
        output_dict["tokens"] = tokens

        return output_dict

    default_predictor = "masked_language_model"
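
The advanced-indexing trick used above to pull out only the embeddings at the mask positions is easy to miss. Below is a minimal sketch in plain PyTorch with made-up shapes (batch of 2, 5 tokens, 3-dim embeddings, 2 masks per instance); the concrete numbers are illustrative only.

import torch

batch_size, num_tokens, dim, num_masks = 2, 5, 3, 2
contextual_embeddings = torch.randn(batch_size, num_tokens, dim)
mask_positions = torch.tensor([[1, 3], [0, 4]])  # (batch_size, num_masks)

# batch_index has shape (batch_size, 1) and broadcasts against mask_positions,
# so the result picks row b, position mask_positions[b, m] for every (b, m).
batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
mask_embeddings = contextual_embeddings[batch_index, mask_positions]

assert mask_embeddings.shape == (batch_size, num_masks, dim)
# Equivalent, written out explicitly:
manual = torch.stack(
    [contextual_embeddings[b, mask_positions[b]] for b in range(batch_size)]
)
assert torch.equal(mask_embeddings, manual)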
Example #14
class LanguageModel(Model):
    """
    The ``LanguageModel`` applies a "contextualizing"
    ``Seq2SeqEncoder`` to uncontextualized embeddings, using a ``SoftmaxLoss``
    module (defined above) to compute the language modeling loss.

    If bidirectional is True, the language model is trained to predict the next and
    previous tokens for each token in the input. In this case, the contextualizer must
    be bidirectional. If bidirectional is False, the language model is trained to only
    predict the next token for each token in the input; the contextualizer should also
    be unidirectional.

    If your language model is bidirectional, it is IMPORTANT that your bidirectional
    ``Seq2SeqEncoder`` contextualizer does not do any "peeking ahead". That is, for its
    forward direction it should only consider embeddings at previous timesteps, and for
    its backward direction only embeddings at subsequent timesteps. Similarly, if your
    language model is unidirectional, the unidirectional contextualizer should only
    consider embeddings at previous timesteps. If this condition is not met, your
    language model is cheating.

    Parameters
    ----------
    vocab: ``Vocabulary``
    text_field_embedder: ``TextFieldEmbedder``
        Used to embed the indexed tokens we get in ``forward``.
    contextualizer: ``Seq2SeqEncoder``
        Used to "contextualize" the embeddings. As described above,
        this encoder must not cheat by peeking ahead.
    dropout: ``float``, optional (default: None)
        If specified, dropout is applied to the contextualized embeddings before computation of
        the softmax. The contextualized embeddings themselves are returned without dropout.
    num_samples: ``int``, optional (default: None)
        If provided, the model will use ``SampledSoftmaxLoss``
        with the specified number of samples. Otherwise, it will use
        the full ``_SoftmaxLoss`` defined above.
    sparse_embeddings: ``bool``, optional (default: False)
        Passed on to ``SampledSoftmaxLoss`` if True.
    bidirectional: ``bool``, optional (default: False)
        Train a bidirectional language model, where the contextualizer
        is used to predict the next and previous token for each input token.
        This must match the bidirectionality of the contextualizer.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 dropout: float = None,
                 num_samples: int = None,
                 sparse_embeddings: bool = False,
                 bidirectional: bool = False,
                 initializer: InitializerApplicator = None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer = contextualizer
        self._bidirectional = bidirectional

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        # TODO(joelgrus): more sampled softmax configuration options, as needed.
        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings)
        else:
            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                              embedding_dim=self._forward_dim)

        # This buffer is now unused and exists only for backwards compatibility reasons.
        self.register_buffer('_last_average_loss', torch.zeros(1))

        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)

    def _get_target_token_embeddings(self, token_embeddings: torch.Tensor,
                                     mask: torch.Tensor,
                                     direction: int) -> torch.Tensor:
        # Need to shift the mask in the correct direction
        zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte()
        if direction == 0:
            # forward direction, get token to right
            shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
        else:
            shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
        return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(
            -1, self._forward_dim)

    def _compute_loss(
        self,
        lm_embeddings: torch.Tensor,
        token_embeddings: torch.Tensor,
        forward_targets: torch.Tensor,
        backward_targets: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If bidirectional, lm_embeddings is shape (batch_size, timesteps, dim * 2)
        # If unidirectional, lm_embeddings is shape (batch_size, timesteps, dim)
        # forward_targets, backward_targets (None in the unidirectional case) are
        # shape (batch_size, timesteps) masked with 0
        if self._bidirectional:
            forward_embeddings, backward_embeddings = lm_embeddings.chunk(
                2, -1)
            backward_loss = self._loss_helper(1, backward_embeddings,
                                              backward_targets,
                                              token_embeddings)
        else:
            forward_embeddings = lm_embeddings
            backward_loss = None

        forward_loss = self._loss_helper(0, forward_embeddings,
                                         forward_targets, token_embeddings)
        return forward_loss, backward_loss

    def _loss_helper(
            self,  # pylint: disable=inconsistent-return-statements
            direction: int,
            direction_embeddings: torch.Tensor,
            direction_targets: torch.Tensor,
            token_embeddings: torch.Tensor) -> torch.Tensor:
        mask = direction_targets > 0
        # we need to subtract 1 to undo the padding id since the softmax
        # does not include a padding dimension

        # shape (batch_size * timesteps, )
        non_masked_targets = direction_targets.masked_select(mask) - 1

        # shape (batch_size * timesteps, embedding_dim)
        non_masked_embeddings = direction_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)
        # note: need to return average loss across forward and backward
        # directions, but total sum loss across all batches.
        # Assuming batches include full sentences, forward and backward
        # directions have the same number of samples, so sum up loss
        # here then divide by 2 just below
        if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
            return self._softmax_loss(non_masked_embeddings,
                                      non_masked_targets)
        else:
            # we also need the token embeddings corresponding to the targets
            raise NotImplementedError(
                "This requires SampledSoftmaxLoss, which isn't implemented yet."
            )
            # pylint: disable=unreachable
            non_masked_token_embeddings = self._get_target_token_embeddings(
                token_embeddings, mask, direction)
            return self._softmax(non_masked_embeddings, non_masked_targets,
                                 non_masked_token_embeddings)

    def delete_softmax(self) -> None:
        """
        Remove the softmax weights. Useful for saving memory when calculating the loss
        is not necessary, e.g. in an embedder.
        """
        self._softmax_loss = None

    def num_layers(self) -> int:
        """
        Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
        the non-contextual layer.
        """
        if hasattr(self._contextualizer, 'num_layers'):
            return self._contextualizer.num_layers + 1
        else:
            raise NotImplementedError(
                f"Contextualizer of type {type(self._contextualizer)} " +
                "does not report how many layers it has.")

    def forward(
        self,  # type: ignore
        source: Dict[str, torch.LongTensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Computes the averaged forward (and backward, if language model is bidirectional)
        LM loss from the batch.

        Parameters
        ----------
        source: ``Dict[str, torch.LongTensor]``, required.
            The output of ``Batch.as_tensor_dict()`` for a batch of sentences. By convention,
            it's required to have at least a ``"tokens"`` entry that's the output of a
            ``SingleIdTokenIndexer``, which is used to compute the language model targets.

        Returns
        -------
        Dict with keys:

        ``'loss'``: ``torch.Tensor``
            forward negative log likelihood, or the average of forward/backward
            if language model is bidirectional
        ``'forward_loss'``: ``torch.Tensor``
            forward direction negative log likelihood
        ``'backward_loss'``: ``torch.Tensor`` or ``None``
            backward direction negative log likelihood. If language model is not
            bidirectional, this is ``None``.
        ``'lm_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]``
            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
            list of all layers. No dropout applied.
        ``'noncontextual_token_embeddings'``: ``torch.Tensor``
            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
            representations
        ``'mask'``: ``torch.Tensor``
            (batch_size, timesteps) mask for the embeddings
        """
        # pylint: disable=arguments-differ
        mask = get_text_field_mask(source)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self._text_field_embedder(source)

        # Either the top layer or all layers.
        contextual_embeddings: Union[
            torch.Tensor,
            List[torch.Tensor]] = self._contextualizer(embeddings, mask)

        return_dict = {}

        # If we have target tokens, calculate the loss.
        token_ids = source.get("tokens")
        if token_ids is not None:
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Use token_ids to compute targets
            forward_targets = torch.zeros_like(token_ids)
            forward_targets[:, 0:-1] = token_ids[:, 1:]

            if self._bidirectional:
                backward_targets = torch.zeros_like(token_ids)
                backward_targets[:, 1:] = token_ids[:, 0:-1]
            else:
                backward_targets = None

            # add dropout
            contextual_embeddings_with_dropout = self._dropout(
                contextual_embeddings)

            # compute softmax loss
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, embeddings,
                forward_targets, backward_targets)

            num_targets = torch.sum((forward_targets > 0).long())
            if num_targets > 0:
                if self._bidirectional:
                    average_loss = 0.5 * (forward_loss +
                                          backward_loss) / num_targets.float()
                else:
                    average_loss = forward_loss / num_targets.float()
            else:
                average_loss = torch.tensor(0.0).to(forward_targets.device)  # pylint: disable=not-callable

            self._perplexity(average_loss)

            if num_targets > 0:
                return_dict.update({
                    'loss':
                    average_loss,
                    'forward_loss':
                    forward_loss / num_targets.float(),
                    'backward_loss': (backward_loss / num_targets.float()
                                      if backward_loss is not None else None),
                    'batch_weight':
                    num_targets.float()
                })
            else:
                # average_loss is a zero tensor here; return it for every loss key
                return_dict.update({
                    'loss':
                    average_loss,
                    'forward_loss':
                    average_loss,
                    'backward_loss':
                    average_loss if backward_loss is not None else None
                })

        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            'lm_embeddings': contextual_embeddings,
            'noncontextual_token_embeddings': embeddings,
            'mask': mask
        })

        return return_dict

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}
Example #15
class MlmHead(Head):
    def __init__(self, vocab: Vocabulary, embedding_dim: int):
        super().__init__(vocab)
        self.vocab_size = vocab.get_vocab_size("tokens")
        self.pad_token_index = vocab.get_token_index("[PAD]")

        self.prediction_head = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim),
            Activation.by_name('gelu')(),
            torch.nn.LayerNorm(embedding_dim, eps=1e-12),
            torch.nn.Linear(embedding_dim, self.vocab_size))

        self._accuracy = CategoricalAccuracy()
        self._perplexity = Perplexity()

    @overrides
    def forward(
        self,  # type: ignore
        encoded_masked_text: torch.Tensor,
        masked_text_labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:

        prediction_scores = self.prediction_head(encoded_masked_text)

        probs = F.softmax(prediction_scores, dim=-1)
        top_probs, top_indices = probs.topk(k=5, dim=-1)

        output_dict = {
            "prediction_probs": top_probs,
            "top_indices": top_indices
        }

        if masked_text_labels is not None:
            # Gather all masked tokens, i.e. all tokens that aren't -100 (= not masked) or padding
            not_modified_mask = (masked_text_labels == -100)
            padding_mask = (masked_text_labels == self.pad_token_index)
            loss_mask = (~(padding_mask | not_modified_mask))

            mask_predictions = prediction_scores[loss_mask]
            mask_labels = masked_text_labels[loss_mask]

            loss = F.cross_entropy(mask_predictions, mask_labels)
            self._perplexity(loss)
            output_dict["loss"] = loss
            output_dict["masked_text_labels"] = masked_text_labels

        return output_dict

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, Any]) -> Dict[str, Any]:
        top_words = []
        for instance_indices in output_dict["top_indices"]:
            top_words.append([[
                # This head predicts over the "tokens" namespace (see __init__).
                self.vocab.get_token_from_index(index.item(), namespace="tokens")
                for index in mask_positions
            ] for mask_positions in instance_indices])
        output_dict["words"] = top_words
        tokens = []
        # "token_ids" is only present if the caller adds it to the output dict.
        for instance_tokens in output_dict.get("token_ids", []):
            tokens.append([
                self.vocab.get_token_from_index(token_id.item(), namespace="tokens")
                for token_id in instance_tokens
            ])
        output_dict["tokens"] = tokens

        return output_dict
Example #16
class AutoregressiveLanguageModel(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        dropout: float = None,
        num_samples: int = None,
        sparse_embeddings: bool = False,
        bidirectional: bool = False,
        initializer: InitializerApplicator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer = contextualizer
        self._bidirectional = bidirectional

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size("transactions"),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
        else:
            self._softmax_loss = SoftmaxLoss(
                num_words=vocab.get_vocab_size("transactions"),
                embedding_dim=self._forward_dim,
            )

        # This buffer is now unused and exists only for backwards compatibility reasons.
        self.register_buffer("_last_average_loss", torch.zeros(1))

        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)

    def _get_target_token_embeddings(self, token_embeddings: torch.Tensor,
                                     mask: torch.BoolTensor,
                                     direction: int) -> torch.Tensor:
        # Need to shift the mask in the correct direction
        zero_col = token_embeddings.new_zeros(mask.size(0),
                                              1).to(dtype=torch.bool)
        if direction == 0:
            # forward direction, get token to right
            shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
        else:
            shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
        return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(
            -1, self._forward_dim)

    def _compute_loss(
        self,
        lm_embeddings: torch.Tensor,
        token_embeddings: torch.Tensor,
        forward_targets: torch.Tensor,
        backward_targets: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If bidirectional, lm_embeddings is shape (batch_size, timesteps, dim * 2)
        # If unidirectional, lm_embeddings is shape (batch_size, timesteps, dim)
        # forward_targets, backward_targets (None in the unidirectional case) are
        # shape (batch_size, timesteps) masked with 0
        if self._bidirectional:
            forward_embeddings, backward_embeddings = lm_embeddings.chunk(
                2, -1)
            backward_loss = self._loss_helper(1, backward_embeddings,
                                              backward_targets,
                                              token_embeddings)
        else:
            forward_embeddings = lm_embeddings
            backward_loss = None

        forward_loss = self._loss_helper(0, forward_embeddings,
                                         forward_targets, token_embeddings)
        return forward_loss, backward_loss

    def _loss_helper(
        self,
        direction: int,
        direction_embeddings: torch.Tensor,
        direction_targets: torch.Tensor,
        token_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        mask = direction_targets > 0
        # we need to subtract 1 to undo the padding id since the softmax
        # does not include a padding dimension

        # shape (batch_size * timesteps, )
        non_masked_targets = direction_targets.masked_select(mask) - 1

        # shape (batch_size * timesteps, embedding_dim)
        non_masked_embeddings = direction_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)
        # note: need to return average loss across forward and backward
        # directions, but total sum loss across all batches.
        # Assuming batches include full sentences, forward and backward
        # directions have the same number of samples, so sum up loss
        # here then divide by 2 just below
        if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
            return self._softmax_loss(non_masked_embeddings,
                                      non_masked_targets)
        else:
            # we also need the token embeddings corresponding to the targets
            raise NotImplementedError(
                "This requires SampledSoftmaxLoss, which isn't implemented yet."
            )

            non_masked_token_embeddings = self._get_target_token_embeddings(
                token_embeddings, mask, direction)
            return self._softmax(non_masked_embeddings, non_masked_targets,
                                 non_masked_token_embeddings)

    def delete_softmax(self) -> None:
        """
        Remove the softmax weights. Useful for saving memory when calculating the loss
        is not necessary, e.g. in an embedder.
        """
        self._softmax_loss = None

    def num_layers(self) -> int:
        """
        Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
        the non-contextual layer.
        """
        if hasattr(self._contextualizer, "num_layers"):
            return self._contextualizer.num_layers + 1
        else:
            raise NotImplementedError(
                f"Contextualizer of type {type(self._contextualizer)} " +
                "does not report how many layers it has.")

    def forward(self, transactions: TextFieldTensors,
                **kwargs) -> Dict[str, torch.Tensor]:

        mask = get_text_field_mask(transactions)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self._text_field_embedder(transactions)

        # Either the top layer or all layers.
        contextual_embeddings: Union[
            torch.Tensor,
            List[torch.Tensor]] = self._contextualizer(embeddings, mask)

        return_dict = {}

        # If we have target transactions, calculate the loss.
        token_id_dict = transactions.get("tokens")
        if token_id_dict is not None:
            token_ids = token_id_dict["tokens"]
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Use token_ids to compute targets
            forward_targets = torch.zeros_like(token_ids)
            forward_targets[:, 0:-1] = token_ids[:, 1:]

            if self._bidirectional:
                backward_targets = torch.zeros_like(token_ids)
                backward_targets[:, 1:] = token_ids[:, 0:-1]
            else:
                backward_targets = None

            # add dropout
            contextual_embeddings_with_dropout = self._dropout(
                contextual_embeddings)

            # compute softmax loss
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout,
                embeddings,
                forward_targets,
                backward_targets,
            )

            num_targets = torch.sum((forward_targets > 0).long())
            if num_targets > 0:
                if self._bidirectional:
                    average_loss = 0.5 * (forward_loss +
                                          backward_loss) / num_targets.float()
                else:
                    average_loss = forward_loss / num_targets.float()
            else:
                average_loss = torch.tensor(0.0).to(forward_targets.device)

            self._perplexity(average_loss)

            if num_targets > 0:
                return_dict.update({
                    "loss":
                    average_loss,
                    "forward_loss":
                    forward_loss / num_targets.float(),
                    "batch_weight":
                    num_targets.float(),
                })
                if backward_loss is not None:
                    return_dict[
                        "backward_loss"] = backward_loss / num_targets.float()
            else:
                # average_loss is a zero tensor here; return it for every loss key
                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": average_loss
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = average_loss

        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            "lm_embeddings": contextual_embeddings,
            "noncontextual_token_embeddings": embeddings,
            "mask": mask,
        })

        return return_dict

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}
Example #17
class NextTokenLM(Model):
    """
    The `NextTokenLM` embeds some input tokens, contextualizes them, then predicts the next word,
    computing a loss against a known target.

    NOTE: This was developed for use in a demo, not for training.  You `definitely` don't want to
    train a language model using this code; it would be incredibly inefficient.  This `does`
    compute correct gradients of the loss, however, so you can use it for interesting visualization
    of the gradients of a pretrained model, and it appears to be fast enough to sample from, at
    least for one word at a time.  If you want to sample many tokens at a time, you'd want to
    re-use some intermediate computation, so you would either need to modify this code or use
    something else.

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the indexed tokens we get in `forward`.
    language_model_head : `LanguageModelHead`
        The `torch.nn.Module` that goes from the hidden states output by the contextualizer to
        logits over some output vocabulary.
    contextualizer : `Seq2SeqEncoder`, optional (default=None)
        Used to "contextualize" the embeddings.  This is optional because the contextualization
        might actually be done in the text field embedder.
    target_namespace : `str`, optional (default='bert')
        Namespace to use to convert predicted token ids to strings in `Model.decode`.
    dropout : `float`, optional (default=0.0)
        If specified, dropout is applied to the contextualized embeddings before computation of
        the softmax. The contextualized embeddings themselves are returned without dropout.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        language_model_head: LanguageModelHead,
        contextualizer: Seq2SeqEncoder = None,
        target_namespace: str = "bert",
        dropout: float = 0.0,
        initializer: InitializerApplicator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._contextualizer = contextualizer
        if contextualizer:
            check_dimensions_match(
                text_field_embedder.get_output_dim(),
                contextualizer.get_input_dim(),
                "text field embedder output",
                "contextualizer input",
            )
        self._language_model_head = language_model_head
        self._target_namespace = target_namespace
        self._perplexity = Perplexity()
        self._dropout = torch.nn.Dropout(dropout)

        if initializer is not None:
            initializer(self)

    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            target_ids: TextFieldTensors = None) -> Dict[str, torch.Tensor]:

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)
        batch_size = embeddings.size(0)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
            final_embeddings = util.get_final_encoder_states(
                contextual_embeddings, mask)
        else:
            final_embeddings = embeddings[:, -1]

        target_logits = self._language_model_head(
            self._dropout(final_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size, 5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)

        if target_ids is not None:
            targets = util.get_token_ids_from_text_field_tensors(
                target_ids).view(batch_size)
            target_logits = target_logits.view(batch_size, vocab_size)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        top_words = []
        for instance_indices in output_dict["top_indices"]:
            top_words.append([[
                self.vocab.get_token_from_index(
                    index.item(), namespace=self._target_namespace)
                for index in instance_indices
            ]])
            output_dict["words"] = top_words
        tokens = []
        print(output_dict["token_ids"])
        for instance_tokens in output_dict["token_ids"]:
            tokens.append([
                self.vocab.get_token_from_index(
                    token_id.item(), namespace=self._target_namespace)
                for token_id in instance_tokens
            ])
        output_dict["tokens"] = tokens

        return output_dict
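
A small standalone sketch (toy logits, hypothetical shapes) of the softmax-plus-top-k step NextTokenLM uses above to expose its most likely next tokens.

import torch

target_logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])    # (batch_size=1, vocab_size=4)
probs = torch.nn.functional.softmax(target_logits, dim=-1)
k = min(target_logits.size(-1), 5)                       # k capped at the vocabulary size
top_probs, top_indices = probs.topk(k=k, dim=-1)
# decode() then maps each index in top_indices to a string with
# vocab.get_token_from_index(index.item(), namespace=target_namespace).
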
Example #18
class MaskedLanguageModel(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        sequence_field_embedder: TextFieldEmbedder,
        structure_field_embedder: TextFieldEmbedder,
        seq2seq_encoder: Seq2SeqEncoder,
        tokens_masker: Optional[TokensMasker] = None,
    ) -> None:
        super().__init__(vocab)
        self._sequence_field_embedder = sequence_field_embedder
        self._structure_field_embedder = structure_field_embedder
        self._seq2seq_encoder = seq2seq_encoder
        self._head = LinearLanguageModelHead(
            vocab=vocab,
            input_dim=self._seq2seq_encoder.get_output_dim(),
            vocab_namespace="sequence")
        self._tokens_masker = tokens_masker

        ignore_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN)
        self._loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        self._perplexity = Perplexity()

    def get_output_dim(self) -> int:
        return self._seq2seq_encoder.get_output_dim()

    def forward(
        self,
        sequence: TextFieldTensors,
        structure: TextFieldTensors,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        mask = get_text_field_mask(sequence)

        if self._tokens_masker is not None:
            sequence, targets = self._tokens_masker.mask_tokens(sequence)
        else:
            targets = None

        sequence_embeddings = self._sequence_field_embedder(sequence)
        structure_embeddings = self._structure_field_embedder(structure)

        # TODO: replace with attention
        sequence_embeddings = torch.cat(
            (sequence_embeddings, structure_embeddings), dim=-1)

        contextual_embeddings = self._seq2seq_encoder(sequence_embeddings,
                                                      mask)

        # take PAD tokens into account when decoding
        logits = self._head(contextual_embeddings)

        output_dict = dict(contextual_embeddings=contextual_embeddings,
                           logits=logits,
                           mask=mask)

        if targets is not None:
            output_dict["loss"] = self._loss(
                logits.transpose(1, 2),
                # TODO: it is not always tokens-tokens
                targets["tokens"]["tokens"],
            )
            self._perplexity(output_dict["loss"])
        return output_dict

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}
class BaseRollinRolloutDecoder(SeqDecoder):
    """
    A base decoder with a roll-in/roll-out formulation, used to define other decoders such as the
    autoregressive decoder, REINFORCE decoder, SEARNN decoder, etc.
    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    decoder_net : ``DecoderNet``, required
        Module that contains implementation of neural network for decoding output elements
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_embedder : ``Embedding``
        Embedder for target tokens.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    beam_size : ``int``, optional (default = 1)
        Width of the beam for beam search.
    tensor_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes raw tensors when it is called.
        This metric must accept two arguments when called: a batched tensor
        of predicted token indices, and a batched tensor of gold token indices.
    token_based_metric : ``Metric``, optional (default = None)
        A metric to track on validation data that takes lists of lists of tokens
        as input. This metric must accept two arguments when called, both
        of type `List[List[str]]`. The first is a predicted sequence for each item
        in the batch and the second is a gold sequence for each item in the batch.
    scheduled_sampling_ratio : ``float``, optional (default = 0)
        Defines the ratio between teacher-forced training and using the model's own output. If it is
        zero (teacher forcing only) and ``decoder_net`` supports parallel decoding, we get the output
        predictions in a single forward pass of the ``decoder_net``.
    """

    default_implementation = "auto_regressive_seq_decoder"
    
    def __init__(self,
                 vocab: Vocabulary,
                 max_decoding_steps: int,
                 decoder_net: DecoderNet,
                 target_embedder: Embedding,
                 loss_criterion: LossCriterion,
                
                 generation_batch_size: int = 200,
                 use_in_seq2seq_mode: bool = False,
                 target_namespace: str = "tokens",
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.0,
                 scheduled_sampling_k: int = 100,
                 scheduled_sampling_type: str = 'uniform',
                 rollin_mode: str = 'mixed',
                 rollout_mode: str = 'learned',

                 dropout: float = None,
                 start_token: str = START_SYMBOL,
                 end_token: str = END_SYMBOL,
                 num_decoder_layers: int = 1,
                 mask_pad_and_oov: bool = False,
                 tie_output_embedding: bool = False,

                 rollout_mixing_prob:float = 0.5,

                 use_bleu: bool = False,
                 use_hamming: bool = False,

                 sample_rollouts: bool = False,
                 beam_search_sampling_temperature: float = 1.,
                 top_k=0, 
                 top_p=0,
                 tensor_based_metric: Metric = None,
                 tensor_based_metric_mask: Metric = None,
                 token_based_metric: Metric = None,
                 eval_beam_size: int = 1,
                ) -> None:
        super().__init__(target_embedder)

        self.current_device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        self._vocab = vocab
        self._seq2seq_mode = use_in_seq2seq_mode

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._max_decoding_steps = max_decoding_steps
        self._generation_batch_size = generation_batch_size
        self._decoder_net = decoder_net

        self._target_namespace = target_namespace

        # TODO #4 (Kushal): Maybe make them modules so that we can add more of these later.
        # TODO #8 #7 (Kushal): Rename "mixed" rollin mode to "scheduled sampling".
        self._rollin_mode = rollin_mode
        self._rollout_mode = rollout_mode

        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._scheduled_sampling_k = scheduled_sampling_k
        self._scheduled_sampling_type = scheduled_sampling_type
        self._sample_rollouts = sample_rollouts
        self._mask_pad_and_oov = mask_pad_and_oov

        self._rollout_mixing_prob = rollout_mixing_prob

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(start_token, self._target_namespace)
        self._end_index = self._vocab.get_token_index(end_token, self._target_namespace)

        self._padding_index = self._vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._target_namespace)
        self._oov_index = self._vocab.get_token_index(DEFAULT_OOV_TOKEN, self._target_namespace)

        if self._mask_pad_and_oov:
            self._vocab_mask = torch.ones(self._vocab.get_vocab_size(self._target_namespace),
                                            device=self.current_device) \
                                    .scatter(0, torch.tensor([self._padding_index, self._oov_index, self._start_index],
                                                                device=self.current_device),
                                                0)
        if use_bleu:
            pad_index = self._vocab.get_token_index(self._vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        if use_hamming:
            self._hamming = HammingLoss()
        else:
            self._hamming = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1

        # TODO(Kushal): Pass in the arguments for sampled. Also, make sure you do not sample in case of Seq2Seq models.
        self._beam_search = SampledBeamSearch(self._end_index, 
                                                max_steps=max_decoding_steps, 
                                                beam_size=beam_size, temperature=beam_search_sampling_temperature)

        self._num_classes = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input." + 
                    f"target_embedder_dim: {self.target_embedder.get_output_dim()}, " + 
                    f"decoder input dim: {self._decoder_net.target_embedding_dim}."
            )

        self._ss_ratio = Average()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self.training_iteration = 0
        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_net.get_output_dim(), self._num_classes)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    f"Can't tie embeddings with output linear layer, due to shape mismatch. " + 
                    f"{self._output_projection_layer.weight.shape} and {self.target_embedder.weight.shape}"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        self._loss_criterion = loss_criterion

        self._top_k = top_k
        self._top_p = top_p
        self._eval_beam_size = eval_beam_size
        self._mle_loss = MaximumLikelihoodLossCriterion()
        self._perplexity = Perplexity()

        # These metrics will be updated during training and validation
        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric
        self._tensor_based_metric_mask = tensor_based_metric_mask
        
        self._decode_tokens = partial(decode_tokens, 
                                    vocab=self._vocab,
                                    start_index=self._start_index,
                                    end_index=self._end_index)
                                    
    def get_output_dim(self):
        return self._decoder_net.get_output_dim()

    def rollin_policy(self,
                      timestep: int,
                      last_predictions: torch.LongTensor,
                      target_tokens: Dict[str, torch.Tensor] = None,
                      rollin_mode = None) -> torch.LongTensor:
        """ Roll-in policy to use.
            This takes in targets, timestep and last_predictions, and decides
            which to use for taking the next step, i.e., generating the next token.
            What to do is decided by the roll-in mode. Options are
                - teacher_forcing,
                - learned,
                - mixed.

            By default the mode is mixed with scheduled_sampling_ratio=0.0, which
            reduces to teacher forcing. You can also explicitly run in
            teacher_forcing mode.

        Arguments:
            timestep {int} -- Current timestep decides which target token to use.
                              In case of teacher_forcing this is usually {t-1}^{th} timestep
                              for predicting t^{th} token.
            last_predictions {torch.LongTensor} -- {t-1}^th token predicted by the model.

        Keyword Arguments:
            targets {torch.LongTensor} -- Targets value if it is available. This will be
                                           available in training mode but not in inference mode. (default: {None})
            rollin_mode {str} -- Rollin mode. Options are
                                  teacher_forcing, learned, mixed (default: {'teacher_forcing'})
        Returns:
            torch.LongTensor -- The method returns input token for predicting next token.
        """
        rollin_mode = rollin_mode or self._rollin_mode

        # For first timestep, you are passing start token, so don't do anything smart.
        if (timestep == 0 or
           # If no targets, no way to do teacher_forcing, so use your own predictions.
           target_tokens is None  or
           rollin_mode == 'learned'):
            # shape: (batch_size,)
            return last_predictions

        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        if rollin_mode == 'teacher_forcing':
            # shape: (batch_size,)
            input_choices = targets[:, timestep]
        elif rollin_mode == 'mixed':
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # During training, use the model's own predictions with probability
                # self._scheduled_sampling_ratio; otherwise fall back to gold tokens.
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]
        else:
            raise ConfigurationError(f"invalid configuration for rollin policy: {rollin_mode}")
        return input_choices

    def copy_reference_policy(self,
                                timestep,
                                last_predictions: torch.LongTensor,
                                state: Dict[str, torch.Tensor],
                                target_tokens: Dict[str, torch.LongTensor],
                              ) -> Tuple[torch.FloatTensor, Dict[str, torch.Tensor]]:
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        seq_len = targets.size(1)
        
        batch_size = last_predictions.shape[0]
        if seq_len > timestep + 1:  # + 1 because timestep is an index, indexed at 0.
            # As we might be overriding the next (predicted) token, we have to use
            # the value corresponding to the {t+1}^{th} timestep.
            target_at_timesteps = targets[:, timestep + 1]
        else:
            # We have overshot the seq_len, so just repeat the
            # last token which is either _end_token or _pad_token.
            target_at_timesteps = targets[:, -1]

        # TODO: Add support to allow other types of reference policies.
        # target_logits: (batch_size, num_classes).
        # This tensor has 0 at targets and (near) -inf at other places.
        target_logits = (target_at_timesteps.new_zeros((batch_size, self._num_classes)) + 1e-45) \
                            .scatter_(dim=1,
                                      index=target_at_timesteps.unsqueeze(1),
                                      value=1.0).log()
        return target_logits, state
    
    def oracle_reference_policy(self, 
                                timestep: int,
                                last_predictions: torch.LongTensor,
                                state: Dict[str, torch.Tensor],
                                token_to_idx: Dict[str, int],
                                idx_to_token: Dict[int, str],
                               ) -> Tuple[torch.FloatTensor, Dict[str, torch.Tensor]]:
        # TODO(Kushal): #5 This is a temporary fix. Ideally, we should have
        # an individual oracle for this which is different from cost function.
        assert hasattr(self._loss_criterion, "_rollout_cost_function") and \
                     hasattr(self._loss_criterion._rollout_cost_function, "_oracle"), \
                "For oracle reference policy, we will need noisy oracle loss function"

        start_time = time.time()
        target_logits, state = self._loss_criterion \
                                    ._rollout_cost_function \
                                    ._oracle \
                                    .reference_step_rollout(
                                        step=timestep,
                                        last_predictions=last_predictions,
                                        state=state,
                                        token_to_idx=token_to_idx,
                                        idx_to_token=idx_to_token)
        end_time = time.time()
        logger.info(f"Oracle Reference time: {end_time - start_time} s")
        return target_logits, state
    
    def rollout_policy(self,
                       timestep: int,
                       last_predictions: torch.LongTensor, 
                       state: Dict[str, torch.Tensor],
                       logits: torch.FloatTensor,
                       reference_policy:ReferencePolicyType,
                       rollout_mode: str = None,
                       rollout_mixing_func: RolloutMixingProbFuncType = None,
                      ) -> Tuple[torch.FloatTensor, Dict[str, torch.Tensor]]:
        """Rollout policy to use.
           This takes in predicted logits at timestep {t}^{th} and
           depending upon the rollout_mode replaces some of the predictions
           with targets.

           The options for rollout mode are:
               - learned,
               - reference,
               - mixed.

        Arguments:
            timestep {int} -- Current timestep; decides which target token to use.
                              In case of reference rollout this is usually the {t-1}^{th} timestep
                              for predicting the t^{th} token.
            logits {torch.FloatTensor} -- Logits generated by the model for the {t}^{th} timestep,
                                          of shape (batch_size, num_classes).

        Keyword Arguments:
            targets {torch.LongTensor} -- Target values, if available. These will be
                                available in training mode but not in inference mode. (default: {None})
            rollout_mode {str} -- Rollout mode. Options are:
                                    learned, reference, mixed. (default: {'learned'})
            rollout_mixing_func {RolloutMixingProbFuncType} -- Function returning a mask used to choose between
                                    predicted logits and targets in case of mixed rollouts. (default: {None})

        Returns:
            torch.FloatTensor -- The method returns logits with the rollout policy applied.
        """
        rollout_mode = rollout_mode or self._rollout_mode
        output_logits = logits


        if rollout_mode == 'learned':
            # For learned rollout policy, just return the same logits.
            return output_logits, state

        target_logits, state = reference_policy(timestep, 
                                                last_predictions,
                                                state)

        batch_size = logits.size(0)
        if rollout_mode == 'reference':
            output_logits += target_logits
        elif rollout_mode == 'mixed':
            # Based on the mask (Value=1), copy target values.

            if rollout_mixing_func is not None:
                rollout_mixing_prob_tensor = rollout_mixing_func()
            else:
                # This returns a (batch_size, num_classes) boolean map where the rows are either all zeros or all ones.
                rollout_mixing_prob_tensor = torch.bernoulli(torch.ones(batch_size) * self._rollout_mixing_prob)

            rollout_mixing_mask = rollout_mixing_prob_tensor \
                                    .unsqueeze(1) \
                                    .expand(logits.shape) \
                                    .to(self.current_device)

            # The target_logits ranges from (-inf , 0), so, by adding those to logits,
            # we turn the values that are not target tokens to -inf, hence making the distribution
            # skew towards the target.
            output_logits += rollout_mixing_mask * target_logits
        else:
            raise ConfigurationError(f"Incompatible rollout mode: {rollout_mode}")
        return output_logits, state

    def take_step(self,
                  timestep: int,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  rollin_policy:RollinPolicyType=default_rollin_policy,
                  rollout_policy:RolloutPolicyType=default_rollout_policy,
                 ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        input_choices = rollin_policy(timestep, last_predictions)

        # Store the timestep in the state; we may use it in _prepare_output_projections.
        state['timestep'] = timestep
        # shape: (group_size, num_classes)
        class_logits, state = self._prepare_output_projections(
                                                last_predictions=input_choices,
                                                state=state)

        if not self.training and self._mask_pad_and_oov:
            # This implementation is copied from masked_log_softmax from allennlp.nn.util.
            mask = (self._vocab_mask.expand(class_logits.shape) + 1e-45).log()
            # shape: (group_size, num_classes)
            class_logits = class_logits + mask

        # shape: (group_size, num_classes)
        class_logits, state = rollout_policy(timestep, last_predictions, state, class_logits)
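        # Optionally restrict the distribution to the k most likely tokens and/or the
        # smallest set of tokens whose cumulative probability exceeds top_p (nucleus filtering).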
        class_logits = top_k_top_p_filtering(class_logits, self._top_k, self._top_p)
        return class_logits, state

    @overrides
    def forward(self,  # type: ignore
                encoder_out: Dict[str, torch.LongTensor] = {},
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        source_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        output_dict: Dict[str, torch.Tensor] = {}
        state: Dict[str, torch.Tensor] = {}
        decoder_init_state: Dict[str, torch.Tensor] = {}

        state.update(copy.copy(encoder_out))
        # In Seq2Seq setting, we will encode the source sequence,
        # and init the state object with encoder output and decoder
        # cell will use these encoder outputs for attention/initing
        # the decoder states.
        if self._seq2seq_mode:
            decoder_init_state = \
                        self._decoder_net.init_decoder_state(state)
            state.update(decoder_init_state)

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        start_predictions: torch.LongTensor = \
                self._get_start_predictions(state,
                                        target_tokens,
                                        self._generation_batch_size)
        
        # In case we have target_tokens, roll-in and roll-out
        # only till those many steps, otherwise we roll-out for
        # `self._max_decoding_steps`.
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets: torch.LongTensor = \
                    util.get_token_ids_from_text_field_tensors(target_tokens)

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps: int = target_sequence_length - 1
        else:
            num_decoding_steps: int = self._max_decoding_steps

        if target_tokens:
            decoder_output_dict, rollin_dict, rollout_dict_iter = \
                                        self._forward_loop(
                                                state=state,
                                                start_predictions=start_predictions,
                                                num_decoding_steps=num_decoding_steps,
                                                target_tokens=target_tokens)

            output_dict.update(decoder_output_dict)
            predictions = decoder_output_dict['predictions']
            predicted_tokens = self._decode_tokens(predictions,
                                                    vocab_namespace=self._target_namespace,
                                                    truncate=True)
            output_dict["decoded_predictions"] = predicted_tokens

            decoded_targets = self._decode_tokens(targets,
                                    vocab_namespace=self._target_namespace,
                                    truncate=True)
            output_dict["decoded_targets"] = decoded_targets

            output_dict.update(self._loss_criterion(
                                            rollin_output_dict=rollin_dict, 
                                            rollout_output_dict_iter=rollout_dict_iter, 
                                            state=state, 
                                            target_tokens=target_tokens))

            mle_loss_output = self._mle_loss(
                                    rollin_output_dict=rollin_dict, 
                                    rollout_output_dict_iter=rollout_dict_iter, 
                                    state=state, 
                                    target_tokens=target_tokens)

            mle_loss = mle_loss_output['loss']
            self._perplexity(mle_loss)

        if not self.training:
            # While validating or testing we need to roll out the learned policy and the output
            # of this rollout is used to compute the secondary metrics
            # like BLEU.
            state: Dict[str, torch.Tensor] = {}
            state.update(copy.copy(encoder_out))
            state.update(decoder_init_state)

            rollout_output_dict = self.rollout(state,
                                        start_predictions,
                                        rollout_steps=num_decoding_steps,
                                        rollout_mode='learned',
                                        sampled=self._sample_rollouts,
                                        beam_size=self._eval_beam_size,
                                        # TODO #6 (Kushal): Add a reason why truncate_at_end_all is False here.
                                        truncate_at_end_all=False)

            output_dict.update(rollout_output_dict)

            predictions = rollout_output_dict['predictions']
            predicted_tokens = self._decode_tokens(predictions,
                                                vocab_namespace=self._target_namespace,
                                                truncate=True)
            output_dict["decoded_predictions"] = predicted_tokens
            decoded_predictions = [predictions[0] \
                                    for predictions in output_dict["decoded_predictions"]]


            # shape (predictions): (batch_size, beam_size, num_decoding_steps)
            predictions = rollout_output_dict['predictions']

            # shape (best_predictions): (batch_size, num_decoding_steps)
            best_predictions = predictions[:, 0, :]

            if target_tokens:
                targets = util.get_token_ids_from_text_field_tensors(target_tokens)
                target_mask = util.get_text_field_mask(target_tokens)
                decoded_targets = self._decode_tokens(targets,
                                        vocab_namespace=self._target_namespace,
                                        truncate=True)

                # TODO #3 (Kushal): Maybe abstract out these losses and use loss_metric like AllenNLP uses.
                if self._bleu and target_tokens:
                    self._bleu(best_predictions, targets)

                if self._hamming and target_tokens:
                    self._hamming(best_predictions, targets, target_mask)

                if self._tensor_based_metric is not None:
                    self._tensor_based_metric(  # type: ignore
                        predictions=best_predictions,
                        gold_targets=targets,
                    )
                if self._tensor_based_metric_mask is not None:
                    self._tensor_based_metric_mask(  # type: ignore
                        predictions=best_predictions,
                        gold_targets=targets,
                        mask=~target_mask,
                    )

                if self._token_based_metric is not None:
                    self._token_based_metric(  # type: ignore
                            predictions=decoded_predictions, 
                            gold_targets=decoded_targets,
                        )
        return output_dict

    @overrides
    def post_process(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        all_predicted_tokens = self._decode_tokens(predicted_indices, 
                                                    vocab_namespace=self._target_namespace,
                                                    truncate=True)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _apply_scheduled_sampling(self):
        if not self.training:
            raise RuntimeError("Scheduled Sampling can only be applied during training.")

        k = self._scheduled_sampling_k
        i = self.training_iteration
        if self._scheduled_sampling_type == 'uniform':
            # This is the same scheduled sampling ratio as set by the config.
            pass
        elif self._scheduled_sampling_type == 'linear':
            self._scheduled_sampling_ratio = i/float(k)
        elif self._scheduled_sampling_type == 'inverse_sigmoid':
            self._scheduled_sampling_ratio = 1 - k/(k + math.exp(i/k))
        else:
            raise ConfigurationError(f"{self._scheduled_sampling_type} is not a valid scheduled sampling type.")

        self._ss_ratio(self._scheduled_sampling_ratio)

    def rollin(self,
               state: Dict[str, torch.Tensor],
               start_predictions: torch.LongTensor,
               rollin_steps: int,
               target_tokens: Dict[str, torch.LongTensor] = None,
               beam_size: int = 1,
               per_node_beam_size: int = None,
               sampled: bool = False,
               truncate_at_end_all: bool = False,
               rollin_mode: str = None,
              ):
        self.training_iteration += 1

        # We cannot use a class attribute as the default argument value, so the
        # default is None and, if it is None, we set it to num_classes here.
        per_node_beam_size: int = per_node_beam_size or self._num_classes

        if self.training:
            self._apply_scheduled_sampling()

        rollin_policy = partial(self.rollin_policy,
                                target_tokens=target_tokens,
                                rollin_mode=rollin_mode)

        rolling_policy = partial(self.take_step,
                                 rollin_policy=rollin_policy)

        # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
        step_predictions, log_probabilities, logits = \
                    self._beam_search.search(start_predictions,
                                                state,
                                                rolling_policy,
                                                max_steps=rollin_steps,
                                                beam_size=beam_size,
                                                per_node_beam_size=per_node_beam_size,
                                                sampled=sampled,
                                                truncate_at_end_all=truncate_at_end_all)

        logits = torch.cat(logits, dim=2)

        batch_size, beam_size, _ = step_predictions.shape
        start_prediction_length = start_predictions.size(0)
        step_predictions = torch.cat([start_predictions.unsqueeze(1) \
                                        .expand(batch_size, beam_size) \
                                        .reshape(batch_size, beam_size, 1),
                                        step_predictions],
                                     dim=-1)

        output_dict = {
            "predictions": step_predictions,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }
        return output_dict

    def rollin_parallel(self, 
                        state: Dict[str, torch.Tensor],
                        start_predictions: torch.LongTensor,
                        rollin_steps: int,
                        target_tokens: Dict[str, torch.LongTensor] = None,
                        beam_size: int = 1,
                        per_node_beam_size: int = None,
                        sampled: bool = False,
                        truncate_at_end_all: bool = False,
                        rollin_mode: str = None,
                    ):
        assert self._decoder_net.decodes_parallel, \
            "Rollin Parallel is only applicable for transformer style decoders" + \
            "that decode whole sequence in parallel."
        
        assert not rollin_mode or rollin_mode == "learned", \
            "Parallel Decoding only works when following " + \
            "teacher forcing rollin policy (rollin_mode='learned')."

        assert self._scheduled_sampling_ratio == 0, \
            "For learned rollin mode, scheduled sampling ratio should always be 0."

        self.training_iteration += 1

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        _, decoder_output = self._decoder_net(
            previous_state=state,
            previous_steps_predictions=target_embedding[:, :-1, :],
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_mask=target_mask[:, :-1],
        )

        # shape: (group_size, max_target_sequence_length, num_classes)
        logits = self._output_projection_layer(decoder_output)

        # Unsqueeze logit to add beam size dimension.
        logits = logits.unsqueeze(dim=1)

        log_probabilities, step_predictions = torch.max(logits, dim=-1)

        return {
            "predictions": step_predictions,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }

    def rollout(self,
                state: Dict[str, torch.Tensor],
                start_predictions: torch.LongTensor,
                rollout_steps: int,
                beam_size: int = None,
                per_node_beam_size: int = None,
                target_tokens: Dict[str, torch.LongTensor] = None,
                sampled: bool = True,
                truncate_at_end_all: bool = True,
                # shape (prediction_prefixes): (batch_size, prefix_length)
                prediction_prefixes: torch.LongTensor = None,
                target_prefixes: torch.LongTensor = None,
                rollout_mixing_func: RolloutMixingProbFuncType = None,
                reference_policy_type:str = "copy",
                rollout_mode: str = None,
               ):
        state['rollout_params'] = {}
        if reference_policy_type == 'oracle':
            reference_policy = partial(self.oracle_reference_policy,
                                        token_to_idx=self._vocab._token_to_index['target_tokens'],
                                        idx_to_token=self._vocab._index_to_token['target_tokens'],
                                       )
            num_steps_to_take = rollout_steps
            state['rollout_params']['rollout_prefixes'] = prediction_prefixes
        else:
            reference_policy = partial(self.copy_reference_policy,
                                        target_tokens=target_tokens)           
            num_steps_to_take = rollout_steps

        rollout_policy = partial(self.rollout_policy,
                                    rollout_mode=rollout_mode,
                                    rollout_mixing_func=rollout_mixing_func,
                                    reference_policy=reference_policy,
                                )
        rolling_policy=partial(self.take_step,
                               rollout_policy=rollout_policy)

        # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
        step_predictions, log_probabilities, logits = \
                    self._beam_search.search(start_predictions,
                                                state,
                                                rolling_policy,
                                                max_steps=num_steps_to_take,
                                                beam_size=beam_size,
                                                per_node_beam_size=per_node_beam_size,
                                                sampled=sampled,
                                                truncate_at_end_all=truncate_at_end_all)

        logits = torch.cat(logits, dim=2)
        
        # Concatenate the start tokens to the predictions. They are not
        # added to the predictions by default.
        batch_size, beam_size, _ = step_predictions.shape

        start_prediction_length = start_predictions.size(0)
        step_predictions = torch.cat([start_predictions.unsqueeze(1) \
                                        .expand(batch_size, beam_size) \
                                        .reshape(batch_size, beam_size, 1),
                                        step_predictions],
                                        dim=-1)

        # There might be some predictions which might have been made by
        # rollin policy. If passed, concatenate them here.
        if prediction_prefixes is not None:
            prefixes_length = prediction_prefixes.size(1)
            step_predictions = torch.cat([prediction_prefixes.unsqueeze(1)\
                                            .expand(batch_size, beam_size, prefixes_length), 
                                         step_predictions],
                                         dim=-1)

        step_prediction_masks = self._get_mask(step_predictions \
                                                .reshape(batch_size * beam_size, -1)) \
                                        .reshape(batch_size, beam_size, -1)

        output_dict = {
            "predictions": step_predictions,
            "prediction_masks": step_prediction_masks,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }

        step_targets = None
        step_target_masks = None
        if target_tokens is not None:
            step_targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            if target_prefixes is not None:
                prefixes_length = target_prefixes.size(1)
                step_targets = torch.cat([target_prefixes, step_targets], dim=-1)

            step_target_masks = util.get_text_field_mask({'tokens': {'tokens': step_targets}})
            
            output_dict.update({
                "targets": step_targets,
                "target_masks": step_target_masks,
            })
        return output_dict

    def compute_sentence_probs(self,
                               sequences_dict: Dict[str, torch.LongTensor],
                              ) -> torch.FloatTensor:
        """ Given a batch of tokens, compute the per-token log probability of sequences
            given the trained model.

        Arguments:
            sequences_dict {Dict[str, torch.LongTensor]} -- The sequences that need to be scored.

        Returns:
            seq_probs {torch.FloatTensor} -- Probabilities of the sequences.
            seq_lens {torch.LongTensor} -- Lengths of the non-padded sequences.
            per_step_seq_probs {torch.FloatTensor} -- Per-step probability of each prediction in a sequence.
        """
        state = {}
        sequences = util.get_token_ids_from_text_field_tensors(sequences_dict)

        batch_size = sequences.size(0)
        seq_len = sequences.size(1)
        start_predictions = self._get_start_predictions(state,
                                                        sequences_dict,
                                                        batch_size)
        
        # We are now computing the probability of the given sequences, so we use
        # rollin_mode='teacher_forcing' to select tokens from the sequences whose
        # probability we need to compute.
        rollin_output_dict = self.rollin(state={},
                                            start_predictions=start_predictions,
                                            rollin_steps=seq_len - 1,
                                            target_tokens=sequences_dict,
                                            rollin_mode='teacher_forcing',
                                        )

        step_log_probs = F.log_softmax(rollin_output_dict['logits'].squeeze(1), dim=-1)
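        # Gather, at each step, the log probability assigned to the token that
        # actually appears next in the sequence (sequences[:, 1:]).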
        per_step_seq_probs = torch.gather(step_log_probs, 2,
                                          sequences[:,1:].unsqueeze(2)) \
                                            .squeeze(2)

        sequence_mask = util.get_text_field_mask(sequences_dict)
        per_step_seq_probs_summed = torch.sum(per_step_seq_probs * sequence_mask[:, 1:], dim=-1)
        non_batch_dims = tuple(range(1, len(sequence_mask.shape)))

        # shape : (batch_size,)
        sequence_mask_sum = sequence_mask[:, 1:].sum(dim=non_batch_dims)

        # (seq_probs, seq_lens, per_step_seq_probs)
        return torch.exp(per_step_seq_probs_summed/sequence_mask_sum), \
                sequence_mask_sum, \
                torch.exp(per_step_seq_probs)

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      start_predictions: torch.LongTensor,
                      num_decoding_steps: int,
                      target_tokens: Dict[str, torch.LongTensor] = None,
                     ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        raise NotImplementedError()

    def _get_start_predictions(self,
              state: Dict[str, torch.Tensor],
              target_tokens: Dict[str, torch.LongTensor] = None,
              generation_batch_size:int = None) ->  torch.LongTensor:

        if self._seq2seq_mode:
            source_mask = state["source_mask"]
            batch_size = source_mask.size()[0]
        elif target_tokens:
            targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            batch_size = targets.size(0)
        else:
            batch_size = generation_batch_size

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        return torch.zeros((batch_size,),
                            dtype=torch.long,
                            device=self.current_device) \
                    .fill_(self._start_index)

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state.get("encoder_outputs", None)

        # shape: (group_size, max_input_sequence_length)
        source_mask = state.get("source_mask", None)

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions", None)

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(1)

        if previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0:
            # There are no previous steps, except for the start vectors in `last_predictions`.
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1
            )

        decoder_state, decoder_output = self._decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions,
        )
        
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)
        
        if self._decoder_net.decodes_parallel:
            decoder_output = decoder_output[:, -1, :]
        
        # add dropout
        decoder_hidden_with_dropout = self._dropout(decoder_output)

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden_with_dropout)

        return output_projections, state
  
    def _get_mask(self, predictions) -> torch.FloatTensor:
        # SEARNN with KL might not produce sequences that match the target
        # sequence in length. This is especially true for LMs with learned
        # roll-ins, where the observed pattern is that sequence lengths keep
        # shrinking.

        # This code computes a mask from the predicted tokens by finding the
        # first time an EOS token is produced; everything after that is
        # masked out.
        mask = predictions.new_ones(predictions.shape)
        for i, indices in enumerate(predictions.detach().cpu().tolist()):
            if self._end_index in indices:
                end_idx = indices.index(self._end_index)
                mask[i, :end_idx + 1] = 1
                mask[i, end_idx + 1:] = 0
        return mask
        
    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}

        all_metrics.update({
            'ss_ratio': self._ss_ratio.get_metric(reset=reset),
            'training_iter': self.training_iteration,
            'perplexity': self._perplexity.get_metric(reset=reset),
        })

        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))

        if self._hamming and not self.training:
            all_metrics.update({'hamming': self._hamming.get_metric(reset=reset)})

        if self._loss_criterion and self._loss_criterion._shall_compute_rollout_loss:
            all_metrics.update(self._loss_criterion.get_metric(reset=reset))

        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset)  # type: ignore
                )
            if self._token_based_metric is not None:
                all_metrics.update(self._token_based_metric.get_metric(reset=reset))  # type: ignore

        return all_metrics
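
A minimal, self-contained sketch of the perplexity bookkeeping the snippets above rely on, under the assumption that the metric simply accumulates per-batch average losses and reports the exponential of their running mean. This toy class only stands in for the `Perplexity` metric used above; it is not the library implementation.

import math

import torch


class ToyPerplexity:
    """Accumulates per-batch average losses and reports exp(mean loss)."""

    def __init__(self) -> None:
        self._total = 0.0
        self._count = 0

    def __call__(self, average_loss: torch.Tensor) -> None:
        # `average_loss` is an already length-normalized NLL for one batch.
        self._total += float(average_loss)
        self._count += 1

    def get_metric(self, reset: bool = False) -> float:
        value = math.exp(self._total / self._count) if self._count else 0.0
        if reset:
            self._total, self._count = 0.0, 0
        return value


ppl = ToyPerplexity()
ppl(torch.tensor(2.0))
ppl(torch.tensor(4.0))
print(ppl.get_metric(reset=True))  # exp(3.0) ~ 20.09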
Example #20
0
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        aux_contextualizer: Seq2SeqEncoder,
        dropout: float = None,
        num_samples: int = None,
        sparse_embeddings: bool = False,
        bidirectional: bool = False,
        initializer: InitializerApplicator = None,
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer_lang1 = aux_contextualizer
        self._contextualizer_lang2 = copy.deepcopy(aux_contextualizer)
        self._contextualizer = contextualizer

        self._bidirectional = bidirectional
        self._bidirectional_aux = aux_contextualizer.is_bidirectional()

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        # main contextualizer forward dim
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        # aux contextualizer forward dim
        if self._bidirectional_aux:
            self._forward_dim_aux = aux_contextualizer.get_output_dim() // 2
        else:
            self._forward_dim_aux = aux_contextualizer.get_output_dim()

        # TODO(joelgrus): more sampled softmax configuration options, as needed.
        if num_samples is not None:
            self._lang1_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
            self._lang2_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
            self._cm_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
        else:
            self._lang1_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux)
            self._lang2_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux)
            self._cm_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim)

        # This buffer is now unused and exists only for backwards compatibility reasons.
        self.register_buffer("_last_average_loss", torch.zeros(1))

        self._lang1_perplexity = Perplexity()
        self._lang2_perplexity = Perplexity()
        self._cm_perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)
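
As a rough stand-in for the `_SoftmaxLoss` / `SampledSoftmaxLoss` objects constructed above (an assumption about their role, not the library code): a full-softmax language-model loss amounts to projecting the hidden states onto the vocabulary and summing the negative log-likelihood of the non-masked targets, which the callers then divide by the number of targets.

import torch
import torch.nn.functional as F


def full_softmax_lm_loss(
    embeddings: torch.Tensor,      # (num_targets, forward_dim)
    targets: torch.Tensor,         # (num_targets,), padding already removed
    softmax_weight: torch.Tensor,  # (vocab_size, forward_dim)
    softmax_bias: torch.Tensor,    # (vocab_size,)
) -> torch.Tensor:
    # Project hidden states onto the vocabulary and sum the target NLL;
    # the models above divide the result by the number of targets.
    logits = F.linear(embeddings, softmax_weight, softmax_bias)
    return F.cross_entropy(logits, targets, reduction="sum")


torch.manual_seed(0)
emb = torch.randn(5, 16)
tgt = torch.randint(0, 100, (5,))
weight, bias = torch.randn(100, 16), torch.zeros(100)
print(full_softmax_lm_loss(emb, tgt, weight, bias))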
Example #21
0
class GirNetLM(Model):
    """
    The ``GirNetLM`` applies a "contextualizing" ``Seq2SeqEncoder`` to
    uncontextualized embeddings and uses a ``SoftmaxLoss`` module (defined above)
    to compute the language modeling loss. The contextualizer must implement
    ``is_bidirectional()``.

    If bidirectional is True,  the language model is trained to predict the next and
    previous tokens for each token in the input. In this case, the contextualizer must
    be bidirectional. If bidirectional is False, the language model is trained to only
    predict the next token for each token in the input; the contextualizer should also
    be unidirectional.

    If your language model is bidirectional, it is IMPORTANT that your bidirectional
    ``Seq2SeqEncoder`` contextualizer does not do any "peeking ahead". That is, for its
    forward direction it should only consider embeddings at previous timesteps, and for
    its backward direction only embeddings at subsequent timesteps. Similarly, if your
    language model is unidirectional, the unidirectional contextualizer should only
    consider embeddings at previous timesteps. If this condition is not met, your
    language model is cheating.

    Parameters
    ----------
    vocab: ``Vocabulary``
    text_field_embedder: ``TextFieldEmbedder``
        Used to embed the indexed tokens we get in ``forward``.
    contextualizer: ``Seq2SeqEncoder``
        Used to "contextualize" the embeddings. As described above,
        this encoder must not cheat by peeking ahead.
    dropout: ``float``, optional (default: None)
        If specified, dropout is applied to the contextualized embeddings before computation of
        the softmax. The contextualized embeddings themselves are returned without dropout.
    num_samples: ``int``, optional (default: None)
        If provided, the model will use ``SampledSoftmaxLoss``
        with the specified number of samples. Otherwise, it will use
        the full ``_SoftmaxLoss`` defined above.
    sparse_embeddings: ``bool``, optional (default: False)
        Passed on to ``SampledSoftmaxLoss`` if True.
    bidirectional: ``bool``, optional (default: False)
        Train a bidirectional language model, where the contextualizer
        is used to predict the next and previous token for each input token.
        This must match the bidirectionality of the contextualizer.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    aux_contextualizer : ``Seq2SeqEncoder``
        Auxiliary contextualizer applied to each of the two individual languages. It is
        deep-copied internally so that ``lang1`` and ``lang2`` use independent parameters.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        aux_contextualizer: Seq2SeqEncoder,
        dropout: float = None,
        num_samples: int = None,
        sparse_embeddings: bool = False,
        bidirectional: bool = False,
        initializer: InitializerApplicator = None,
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer_lang1 = aux_contextualizer
        self._contextualizer_lang2 = copy.deepcopy(aux_contextualizer)
        self._contextualizer = contextualizer

        self._bidirectional = bidirectional
        self._bidirectional_aux = aux_contextualizer.is_bidirectional()

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        # main contextualizer forward dim
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        # aux contextualizer forward dim
        if self._bidirectional_aux:
            self._forward_dim_aux = aux_contextualizer.get_output_dim() // 2
        else:
            self._forward_dim_aux = aux_contextualizer.get_output_dim()

        # TODO(joelgrus): more sampled softmax configuration options, as needed.
        if num_samples is not None:
            self._lang1_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
            self._lang2_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
            self._cm_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
        else:
            self._lang1_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux)
            self._lang2_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux)
            self._cm_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim)

        # This buffer is now unused and exists only for backwards compatibility reasons.
        self.register_buffer("_last_average_loss", torch.zeros(1))

        self._lang1_perplexity = Perplexity()
        self._lang2_perplexity = Perplexity()
        self._cm_perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)

    # SAFE
    def _get_target_token_embeddings(self, token_embeddings: torch.Tensor,
                                     mask: torch.Tensor,
                                     direction: int) -> torch.Tensor:
        # Need to shift the mask in the correct direction
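        # Illustration (assumed shapes): for a mask row [1, 1, 1, 0], direction 0
        # shifts it to [0, 1, 1, 1] (select the embedding of the token to the
        # right of each valid position); direction 1 shifts it to [1, 1, 0, 0].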
        zero_col = token_embeddings.new_zeros(mask.size(0),
                                              1).to(dtype=torch.bool)
        if direction == 0:
            # forward direction, get token to right
            shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
        else:
            shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
        return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(
            -1, self._forward_dim)

    def is_label_bidirectional(self, label):
        if label == 'lang1' or label == 'lang2':
            return self._bidirectional_aux
        elif label == 'cm':
            return self._bidirectional
        else:
            raise Exception(f"unknown label {label}")

    def _compute_loss(self,
                      lm_embeddings: torch.Tensor,
                      token_embeddings: torch.Tensor,
                      forward_targets: torch.Tensor,
                      backward_targets: torch.Tensor = None,
                      label="cm") -> Tuple[torch.Tensor, torch.Tensor]:
        # If bidirectional, lm_embeddings is shape (batch_size, timesteps, dim * 2)
        # If unidirectional, lm_embeddings is shape (batch_size, timesteps, dim)
        # forward_targets, backward_targets (None in the unidirectional case) are
        # shape (batch_size, timesteps) masked with 0
        if self.is_label_bidirectional(label):
            forward_embeddings, backward_embeddings = lm_embeddings.chunk(
                2, -1)
            backward_loss = self._loss_helper(1,
                                              backward_embeddings,
                                              backward_targets,
                                              token_embeddings,
                                              label=label)
        else:
            forward_embeddings = lm_embeddings
            backward_loss = None

        forward_loss = self._loss_helper(0,
                                         forward_embeddings,
                                         forward_targets,
                                         token_embeddings,
                                         label=label)
        return forward_loss, backward_loss

    def _loss_helper(self,
                     direction: int,
                     direction_embeddings: torch.Tensor,
                     direction_targets: torch.Tensor,
                     token_embeddings: torch.Tensor,
                     label="cm") -> Tuple[int, int]:
        mask = direction_targets > 0
        # we need to subtract 1 to undo the padding id since the softmax
        # does not include a padding dimension

        # shape (batch_size * timesteps, )
        non_masked_targets = direction_targets.masked_select(mask) - 1

        if label is "lang1" or label is "lang2":
            non_masked_embeddings = direction_embeddings.masked_select(
                mask.unsqueeze(-1)).view(-1, self._forward_dim_aux)
        else:
            # shape (batch_size * timesteps, embedding_dim)
            non_masked_embeddings = direction_embeddings.masked_select(
                mask.unsqueeze(-1)).view(-1, self._forward_dim)
        # note: need to return average loss across forward and backward
        # directions, but total sum loss across all batches.
        # Assuming batches include full sentences, forward and backward
        # directions have the same number of samples, so sum up loss
        # here then divide by 2 just below
        if label is "lang1":
            if not self._lang1_softmax_loss.tie_embeddings or not self._use_character_inputs:
                return self._lang1_softmax_loss(non_masked_embeddings,
                                                non_masked_targets)
        elif label is "lang2":
            if not self._lang2_softmax_loss.tie_embeddings or not self._use_character_inputs:
                return self._lang2_softmax_loss(non_masked_embeddings,
                                                non_masked_targets)
        elif label is "cm":
            if not self._cm_softmax_loss.tie_embeddings or not self._use_character_inputs:
                return self._cm_softmax_loss(non_masked_embeddings,
                                             non_masked_targets)

    def delete_softmax(self) -> None:
        """
        Remove the softmax weights. Useful for saving memory when calculating the loss
        is not necessary, e.g. in an embedder.
        """
        self._lang1_softmax_loss = None
        self._lang2_softmax_loss = None
        self._cm_softmax_loss = None

    # UNSAFE
    def num_layers(self) -> int:
        """
        Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
        the non-contextual layer.
        """
        if hasattr(self._contextualizer, "num_layers"):
            return self._contextualizer.num_layers + 1
        else:
            raise NotImplementedError(
                f"Contextualizer of type {type(self._contextualizer)} " +
                "does not report how many layers it has.")

    def forward(  # type: ignore
            self, lang1: Dict[str, torch.LongTensor],
            lang2: Dict[str, torch.LongTensor],
            cm: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """
        Computes the averaged forward (and backward, if language model is bidirectional)
        LM loss from the batch.

        Parameters
        ----------
        lang1: ``Dict[str, torch.LongTensor]``, required.
            The output of ``Batch.as_tensor_dict()`` for a batch of sentences. By convention,
            it's required to have at least a ``"tokens"`` entry that's the output of a
            ``SingleIdTokenIndexer``, which is used to compute the language model targets.
        lang2: ``Dict[str, torch.LongTensor]``, required.
            The output of ``Batch.as_tensor_dict()`` for a batch of sentences. By convention,
            it's required to have at least a ``"tokens"`` entry that's the output of a
            ``SingleIdTokenIndexer``, which is used to compute the language model targets.
        cm: ``Dict[str, torch.LongTensor]``, required.
            The output of ``Batch.as_tensor_dict()`` for a batch of sentences. By convention,
            it's required to have at least a ``"tokens"`` entry that's the output of a
            ``SingleIdTokenIndexer``, which is used to compute the language model targets.

        Returns
        -------
        Dict with keys:

        ``'loss'``: ``torch.Tensor``
            forward negative log likelihood, or the average of forward/backward
            if language model is bidirectional
        ``'forward_loss'``: ``torch.Tensor``
            forward direction negative log likelihood
        ``'backward_loss'``: ``torch.Tensor`` or ``None``
            backward direction negative log likelihood. If language model is not
            bidirectional, this is ``None``.
        ``'lm_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]``
            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
            list of all layers. No dropout applied.
        ``'noncontextual_token_embeddings'``: ``torch.Tensor``
            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
            representations
        ``'mask'``: ``torch.Tensor``
            (batch_size, timesteps) mask for the embeddings

        """

        # get text field mask for each input; safe operation
        lang1_mask = get_text_field_mask(lang1)
        lang2_mask = get_text_field_mask(lang2)
        cm_mask = get_text_field_mask(cm)

        # shape (batch_size, timesteps, embedding_size)
        lang1_embeddings = self._text_field_embedder(lang1)
        lang2_embeddings = self._text_field_embedder(lang2)
        cm_embeddings = self._text_field_embedder(cm)

        # Either the top layer or all layers.
        lang1_contextual_embeddings: Union[
            torch.Tensor, List[torch.Tensor]] = self._contextualizer_lang1(
                lang1_embeddings, lang1_mask)
        lang2_contextual_embeddings: Union[
            torch.Tensor, List[torch.Tensor]] = self._contextualizer_lang2(
                lang2_embeddings, lang2_mask)

        return_dict = {}

        lang1_dict = self._each_lang_loss(
            mask=lang1_mask,
            source=lang1,
            embeddings=lang1_embeddings,
            contextual_embeddings=lang1_contextual_embeddings,
            label='lang1')
        return_dict.update(lang1_dict)

        lang2_dict = self._each_lang_loss(
            mask=lang2_mask,
            source=lang2,
            embeddings=lang2_embeddings,
            contextual_embeddings=lang2_contextual_embeddings,
            label='lang2')
        return_dict.update(lang2_dict)

        #############
        # GIRNET STUFF
        #############

        # get lang1 and lang2 embedding of code_mixed data
        cm_lang1_contextual_embeddings: Union[
            torch.Tensor, List[torch.Tensor]] = self._contextualizer_lang1(
                cm_embeddings, cm_mask)
        cm_lang2_contextual_embeddings: Union[
            torch.Tensor, List[torch.Tensor]] = self._contextualizer_lang2(
                cm_embeddings, cm_mask)

        # MERGE aux representations
        if self._bidirectional_aux and self._bidirectional:
            # if both of them are bidirectional
            cm_lang1_contextual_embeddings_forward, cm_lang1_contextual_embeddings_backward = cm_lang1_contextual_embeddings.chunk(
                2, -1)
            cm_lang2_contextual_embeddings_forward, cm_lang2_contextual_embeddings_backward = cm_lang2_contextual_embeddings.chunk(
                2, -1)

            cm_cat_contextual_embeddings_forward = torch.cat([
                cm_lang1_contextual_embeddings_forward,
                cm_lang2_contextual_embeddings_forward
            ], -1)
            cm_cat_contextual_embeddings_backward = torch.cat([
                cm_lang1_contextual_embeddings_backward,
                cm_lang2_contextual_embeddings_backward
            ], -1)

            cm_cat_contextual_embeddings = torch.cat([
                cm_cat_contextual_embeddings_forward, cm_embeddings,
                cm_cat_contextual_embeddings_backward, cm_embeddings
            ], -1)
        elif not self._bidirectional_aux and not self._bidirectional:
            # if both of them are unidirectional
            cm_cat_contextual_embeddings = torch.cat([
                cm_lang1_contextual_embeddings, cm_lang2_contextual_embeddings,
                cm_embeddings
            ], -1)
        else:
            raise Exception(
                "contextualizer and aux_contextualizer should have same directionality"
            )

        # Run contextualizer on the merged representation of the input
        # a bidirectional contextualizer splits its input into forward and backward halves (transformer encoders only)
        cm_contextual_embeddings: Union[
            torch.Tensor, List[torch.Tensor]] = self._contextualizer(
                cm_cat_contextual_embeddings, cm_mask)

        cm_dict = self._each_lang_loss(
            mask=cm_mask,
            source=cm,
            embeddings=cm_embeddings,
            contextual_embeddings=cm_contextual_embeddings,
            label='cm')
        return_dict.update(cm_dict)

        # If we have target tokens, calculate the loss.
        token_ids = cm.get("tokens")  # safe
        if token_ids is not None:
            average_loss = (lang1_dict['lang1_loss'] +
                            lang2_dict['lang2_loss'] +
                            (2 * cm_dict['cm_loss'])) / 4
            return_dict.update({"loss": average_loss})

        return return_dict

    def _each_lang_loss(self, mask, source: Dict[str, torch.LongTensor],
                        embeddings, contextual_embeddings: torch.Tensor, label):

        return_dict = {}

        # If we have target tokens, calculate the loss.
        token_ids = source.get("tokens")  # safe
        if token_ids is not None:
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Use token_ids to compute targets
            forward_targets = torch.zeros_like(token_ids)
            forward_targets[:, 0:-1] = token_ids[:, 1:]

            if self.is_label_bidirectional(label):
                backward_targets = torch.zeros_like(token_ids)
                backward_targets[:, 1:] = token_ids[:, 0:-1]
            else:
                backward_targets = None

            # add dropout
            contextual_embeddings_with_dropout = self._dropout(
                contextual_embeddings)

            # compute softmax loss
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout,
                embeddings,
                forward_targets,
                backward_targets,
                label=label)

            num_targets = torch.sum((forward_targets > 0).long())
            if num_targets > 0:
                if self.is_label_bidirectional(label):
                    average_loss = 0.5 * (forward_loss +
                                          backward_loss) / num_targets.float()
                else:
                    average_loss = forward_loss / num_targets.float()
            else:
                average_loss = torch.tensor(0.0).to(forward_targets.device)

            if label == 'lang1':
                self._lang1_perplexity(average_loss)
            elif label == 'lang2':
                self._lang2_perplexity(average_loss)
            elif label == 'cm':
                self._cm_perplexity(average_loss)

            if num_targets > 0:
                return_dict.update({
                    f"{label}_loss":
                    average_loss,
                    f"{label}_forward_loss":
                    forward_loss / num_targets.float(),
                    f"{label}_batch_weight":
                    num_targets.float(),
                })
                if backward_loss is not None:
                    return_dict[
                        f"{label}_backward_loss"] = backward_loss / num_targets.float(
                        )
            else:
                # average_loss zero tensor, return it for all
                return_dict.update({
                    f"{label}_loss": average_loss,
                    "forward_loss": average_loss
                })
                if backward_loss is not None:
                    return_dict[f"{label}_backward_loss"] = average_loss

        if label is "cm":
            return_dict.update({
                "lm_embeddings": contextual_embeddings,
                "noncontextual_token_embeddings": embeddings,
                "mask": mask,
            })
        else:
            return_dict.update({
                f"{label}_lm_embeddings": contextual_embeddings,
                f"{label}_noncontextual_token_embeddings": embeddings,
                f"{label}_mask": mask,
            })
        return return_dict

    def get_metrics(self, reset: bool = False):
        return {
            "ppl_lang1": self._lang1_perplexity.get_metric(reset=reset),
            "ppl_lang2": self._lang2_perplexity.get_metric(reset=reset),
            "ppl_cm": self._cm_perplexity.get_metric(reset=reset)
        }
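
To make the merge step in `forward` above concrete, here is a shape-only sketch with toy tensors (sizes are arbitrary and chosen for illustration; the real dimensions come from the embedder and the aux contextualizers). Each bidirectional aux output is split into forward/backward halves, the halves are concatenated across the two languages, and the token embeddings are appended before the main contextualizer runs.

import torch

batch, timesteps = 2, 5
embed_dim, aux_dim = 8, 12  # illustrative sizes only

cm_embeddings = torch.randn(batch, timesteps, embed_dim)
# Stand-ins for the outputs of the two bidirectional aux contextualizers:
cm_lang1 = torch.randn(batch, timesteps, aux_dim)
cm_lang2 = torch.randn(batch, timesteps, aux_dim)

# Split each aux output into its forward / backward halves.
l1_fwd, l1_bwd = cm_lang1.chunk(2, -1)
l2_fwd, l2_bwd = cm_lang2.chunk(2, -1)

# Concatenate the two languages per direction, then append the token
# embeddings to each direction, as in the bidirectional branch above.
fwd = torch.cat([l1_fwd, l2_fwd], -1)
bwd = torch.cat([l1_bwd, l2_bwd], -1)
merged = torch.cat([fwd, cm_embeddings, bwd, cm_embeddings], -1)

print(merged.shape)  # torch.Size([2, 5, 40]) == aux_dim + embed_dim + aux_dim + embed_dim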
Example #22
0
    def __init__(self,
                 vocab: Vocabulary,
                 max_decoding_steps: int,
                 decoder_net: DecoderNet,
                 target_embedder: Embedding,
                 loss_criterion: LossCriterion,
                
                 generation_batch_size: int = 200,
                 use_in_seq2seq_mode: bool = False,
                 target_namespace: str = "tokens",
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.0,
                 scheduled_sampling_k: int = 100,
                 scheduled_sampling_type: str = 'uniform',
                 rollin_mode: str = 'mixed',
                 rollout_mode: str = 'learned',

                 dropout: float = None,
                 start_token: str = START_SYMBOL,
                 end_token: str = END_SYMBOL,
                 num_decoder_layers: int = 1,
                 mask_pad_and_oov: bool = False,
                 tie_output_embedding: bool = False,

                 rollout_mixing_prob:float = 0.5,

                 use_bleu: bool = False,
                 use_hamming: bool = False,

                 sample_rollouts: bool = False,
                 beam_search_sampling_temperature: float = 1.,
                 top_k: int = 0,
                 top_p: float = 0.0,
                 tensor_based_metric: Metric = None,
                 tensor_based_metric_mask: Metric = None,
                 token_based_metric: Metric = None,
                 eval_beam_size: int = 1,
                ) -> None:
        super().__init__(target_embedder)

        self.current_device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        self._vocab = vocab
        self._seq2seq_mode = use_in_seq2seq_mode

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._max_decoding_steps = max_decoding_steps
        self._generation_batch_size = generation_batch_size
        self._decoder_net = decoder_net

        self._target_namespace = target_namespace

        # TODO #4 (Kushal): Maybe make them modules so that we can add more of these later.
        # TODO #8 #7 (Kushal): Rename "mixed" rollin mode to "scheduled sampling".
        self._rollin_mode = rollin_mode
        self._rollout_mode = rollout_mode

        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._scheduled_sampling_k = scheduled_sampling_k
        self._scheduled_sampling_type = scheduled_sampling_type
        self._sample_rollouts = sample_rollouts
        self._mask_pad_and_oov = mask_pad_and_oov

        self._rollout_mixing_prob = rollout_mixing_prob

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(start_token, self._target_namespace)
        self._end_index = self._vocab.get_token_index(end_token, self._target_namespace)

        self._padding_index = self._vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._target_namespace)
        self._oov_index = self._vocab.get_token_index(DEFAULT_OOV_TOKEN, self._target_namespace)

        if self._mask_pad_and_oov:
            self._vocab_mask = torch.ones(self._vocab.get_vocab_size(self._target_namespace),
                                            device=self.current_device) \
                                    .scatter(0, torch.tensor([self._padding_index, self._oov_index, self._start_index],
                                                                device=self.current_device),
                                                0)
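            # Illustration (assumed indices): with a target vocabulary of size 6,
            # padding_index=0, oov_index=1 and start_index=2, the resulting mask
            # is [0, 0, 0, 1, 1, 1] -- those three special tokens get zeroed out
            # wherever the mask is applied to logits.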
        if use_bleu:
            pad_index = self._vocab.get_token_index(self._vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        if use_hamming:
            self._hamming = HammingLoss()
        else:
            self._hamming = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1

        # TODO(Kushal): Pass in the arguments for sampled. Also, make sure you do not sample in case of Seq2Seq models.
        self._beam_search = SampledBeamSearch(self._end_index, 
                                                max_steps=max_decoding_steps, 
                                                beam_size=beam_size, temperature=beam_search_sampling_temperature)

        self._num_classes = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input." + 
                    f"target_embedder_dim: {self.target_embedder.get_output_dim()}, " + 
                    f"decoder input dim: {self._decoder_net.target_embedding_dim}."
            )

        self._ss_ratio = Average()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self.training_iteration = 0
        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_net.get_output_dim(), self._num_classes)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    f"Can't tie embeddings with output linear layer, due to shape mismatch. " + 
                    f"{self._output_projection_layer.weight.shape} and {self.target_embedder.weight.shape}"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        self._loss_criterion = loss_criterion

        self._top_k = top_k
        self._top_p = top_p
        self._eval_beam_size = eval_beam_size
        self._mle_loss = MaximumLikelihoodLossCriterion()
        self._perplexity = Perplexity()

        # These metrics will be updated during training and validation
        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric
        self._tensor_based_metric_mask = tensor_based_metric_mask
        
        self._decode_tokens = partial(decode_tokens, 
                                    vocab=self._vocab,
                                    start_index=self._start_index,
                                    end_index=self._end_index)
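
A minimal sketch of the `tie_output_embedding` option above, assuming a plain `torch.nn.Embedding` target embedder: tying simply shares the output projection's weight with the embedding matrix, which is why the constructor first checks that the two shapes match.

import torch
from torch.nn import Embedding, Linear

vocab_size, embed_dim = 50, 16  # illustrative sizes
target_embedder = Embedding(vocab_size, embed_dim)
output_projection = Linear(embed_dim, vocab_size, bias=False)

# Both weights are (vocab_size, embed_dim), so the parameters can be shared.
assert output_projection.weight.shape == target_embedder.weight.shape
output_projection.weight = target_embedder.weight

hidden = torch.randn(3, embed_dim)
logits = output_projection(hidden)  # (3, vocab_size), computed with the tied matrix
print(logits.shape)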
Example #23
0
class NameGen(Model):
    """ The `NextTokenLM` embeds some input tokens, contextualizes them, ,then predicts the next work
        If `BeamSearch` is given, this model will predict a sequence of tokens
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder = None,
                 dropout: float = 0.0,
                 num_samples: int = None,
                 sparse_embeddings: bool = False,
                 bidirectional: bool = False,
                 initializer=InitializerApplicator(),
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder
        self._contextualizer = contextualizer
        self._bidirectional = bidirectional

        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        self._softmax_loss = SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                         embedding_dim=self._forward_dim)

        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)

    def _get_target_token_embeddings(self, token_embeddings: torch.Tensor,
                                     mask: torch.BoolTensor,
                                     direction: int) -> torch.Tensor:
        # Need to shift the mask in the correct direction
        zero_col = token_embeddings.new_zeros(mask.size(0),
                                              1).to(dtype=torch.bool)
        if direction == 0:
            # forward direction, get token to right
            shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
        else:
            shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
        return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(
            -1, self._forward_dim)

    def _compute_loss(
        self,
        lm_embeddings: torch.Tensor,
        token_embeddings: torch.Tensor,
        forward_targets: torch.Tensor,
        backward_targets: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If bidirectional, lm_embeddings is shape (batch_size, timesteps, dim * 2)
        # If unidirectional, lm_embeddings is shape (batch_size, timesteps, dim)
        # forward_targets, backward_targets (None in the unidirectional case) are
        # shape (batch_size, timesteps) masked with 0
        if self._bidirectional:
            forward_embeddings, backward_embeddings = lm_embeddings.chunk(
                2, -1)
            backward_loss = self._loss_helper(1, backward_embeddings,
                                              backward_targets,
                                              token_embeddings)
        else:
            forward_embeddings = lm_embeddings
            backward_loss = None

        forward_loss = self._loss_helper(0, forward_embeddings,
                                         forward_targets, token_embeddings)
        return forward_loss, backward_loss

    def _loss_helper(
        self,
        direction: int,
        direction_embeddings: torch.Tensor,
        direction_targets: torch.Tensor,
        token_embeddings: torch.Tensor,
    ) -> Tuple[int, int]:
        mask = direction_targets > 0
        # we need to subtract 1 to undo the padding id since the softmax
        # does not include a padding dimension

        # shape (batch_size * timesteps, )
        non_masked_targets = direction_targets.masked_select(mask) - 1

        # shape (batch_size * timesteps, embedding_dim)
        non_masked_embeddings = direction_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)
        # note: need to return average loss across forward and backward
        # directions, but total sum loss across all batches.
        # Assuming batches include full sentences, forward and backward
        # directions have the same number of samples, so sum up loss
        # here then divide by 2 just below
        if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
            return self._softmax_loss(non_masked_embeddings,
                                      non_masked_targets)
        else:
            # we also need the token embeddings corresponding to the
            # the targets
            raise NotImplementedError(
                "This requires SampledSoftmaxLoss, which isn't implemented yet."
            )

            non_masked_token_embeddings = self._get_target_token_embeddings(
                token_embeddings, mask, direction)
            return self._softmax(non_masked_embeddings, non_masked_targets,
                                 non_masked_token_embeddings)

    def delete_softmax(self) -> None:
        """
        Remove the softmax weights. Useful for saving memory when calculating the loss
        is not necessary, e.g. in an embedder.
        """
        self._softmax_loss = None

    def num_layers(self) -> int:
        """
        Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
        the non-contextual layer.
        """
        if hasattr(self._contextualizer, "num_layers"):
            return self._contextualizer.num_layers + 1
        else:
            raise NotImplementedError(
                f"Contextualizer of type {type(self._contextualizer)} " +
                "does not report how many layers it has.")

    def forward(
            self, tokens: TextFieldTensors
    ) -> Dict[str, torch.Tensor]:  # type: ignore
        """
        Computes the averaged forward (and backward, if language model is bidirectional)
        LM loss from the batch.
        # Parameters
        source : `TextFieldTensors`, required.
            The output of `Batch.as_tensor_dict()` for a batch of sentences. By convention,
            it's required to have at least a `"tokens"` entry that's the output of a
            `SingleIdTokenIndexer`, which is used to compute the language model targets.
        # Returns
        Dict with keys:
        `'loss'` : `torch.Tensor`
            forward negative log likelihood, or the average of forward/backward
            if language model is bidirectional
        `'forward_loss'` : `torch.Tensor`
            forward direction negative log likelihood
        `'backward_loss'` : `torch.Tensor` or `None`
            backward direction negative log likelihood. If language model is not
            bidirectional, this is `None`.
        `'lm_embeddings'` : `Union[torch.Tensor, List[torch.Tensor]]`
            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
            list of all layers. No dropout applied.
        `'noncontextual_token_embeddings'` : `torch.Tensor`
            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
            representations
        `'mask'` : `torch.BoolTensor`
            (batch_size, timesteps) mask for the embeddings
        """

        mask = get_text_field_mask(source)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self._text_field_embedder(source)

        # Either the top layer or all layers.
        contextual_embeddings: Union[
            torch.Tensor,
            List[torch.Tensor]] = self._contextualizer(embeddings, mask)

        return_dict = {}

        # If we have target tokens, calculate the loss.
        token_id_dict = source.get("tokens")
        if token_id_dict is not None:
            token_ids = token_id_dict["tokens"]
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Use token_ids to compute targets
            forward_targets = torch.zeros_like(token_ids)
            forward_targets[:, 0:-1] = token_ids[:, 1:]

            if self._bidirectional:
                backward_targets = torch.zeros_like(token_ids)
                backward_targets[:, 1:] = token_ids[:, 0:-1]
            else:
                backward_targets = None

            # add dropout
            contextual_embeddings_with_dropout = self._dropout(
                contextual_embeddings)

            # compute softmax loss
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, embeddings,
                forward_targets, backward_targets)

            num_targets = torch.sum((forward_targets > 0).long())
            if num_targets > 0:
                if self._bidirectional:
                    average_loss = 0.5 * (forward_loss +
                                          backward_loss) / num_targets.float()
                else:
                    average_loss = forward_loss / num_targets.float()
            else:
                average_loss = torch.tensor(0.0).to(forward_targets.device)

            self._perplexity(average_loss)

            if num_targets > 0:
                return_dict.update({
                    "loss":
                    average_loss,
                    "forward_loss":
                    forward_loss / num_targets.float(),
                    "batch_weight":
                    num_targets.float(),
                })
                if backward_loss is not None:
                    return_dict[
                        "backward_loss"] = backward_loss / num_targets.float()
            else:
                # average_loss zero tensor, return it for all
                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": average_loss
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = average_loss

        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            "lm_embeddings": contextual_embeddings,
            "noncontextual_token_embeddings": embeddings,
            "mask": mask,
        })

        return return_dict

    def get_metrics(self, reset: bool = False):
        return {"perplexity": self._perplexity.get_metric(reset=reset)}