Example #1
    def forward(
            self, source: TextFieldTensors
    ) -> Dict[str, torch.Tensor]:  # type: ignore
        """
        Computes the averaged forward (and backward, if the language model is bidirectional)
        LM loss from the batch.

        # Parameters

        source : `TextFieldTensors`, required.
            The output of `Batch.as_tensor_dict()` for a batch of sentences. By convention,
            it's required to have at least a `"tokens"` entry that's the output of a
            `SingleIdTokenIndexer`, which is used to compute the language model targets.

        # Returns

        Dict with keys:

        `'loss'` : `torch.Tensor`
            forward negative log likelihood, or the average of the forward and backward
            losses if the language model is bidirectional
        `'forward_loss'` : `torch.Tensor`
            forward direction negative log likelihood
        `'backward_loss'` : `torch.Tensor`
            backward direction negative log likelihood; present only if the language
            model is bidirectional
        `'batch_weight'` : `torch.Tensor`
            the number of prediction targets in the batch, as a float scalar; present
            only when there is at least one target
        `'lm_embeddings'` : `Union[torch.Tensor, List[torch.Tensor]]`
            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
            list of all layers. No dropout applied.
        `'noncontextual_token_embeddings'` : `torch.Tensor`
            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
            representations
        `'mask'` : `torch.BoolTensor`
            (batch_size, timesteps) mask for the embeddings
        """

        mask = get_text_field_mask(source)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self._text_field_embedder(source)

        # Either the top layer or all layers.
        contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
            embeddings, mask
        )

        return_dict = {}

        # If we have target tokens, calculate the loss.
        token_id_dict = source.get("tokens")
        if token_id_dict is not None:
            token_ids = token_id_dict["tokens"]
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Use token_ids to compute targets: the forward target at position i is
            # the token at position i + 1, and index 0 marks padding / "no target".
            forward_targets = torch.zeros_like(token_ids)
            forward_targets[:, 0:-1] = token_ids[:, 1:]

            if self._bidirectional:
                backward_targets = torch.zeros_like(token_ids)
                backward_targets[:, 1:] = token_ids[:, 0:-1]
            else:
                backward_targets = None

            # Apply dropout before computing the loss.
            contextual_embeddings_with_dropout = self._dropout(
                contextual_embeddings)

            # compute softmax loss
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, embeddings,
                forward_targets, backward_targets)

            num_targets = torch.sum((forward_targets > 0).long())
            if num_targets > 0:
                if self._bidirectional:
                    average_loss = (
                        0.5 * (forward_loss + backward_loss) / num_targets.float()
                    )
                else:
                    average_loss = forward_loss / num_targets.float()
            else:
                average_loss = torch.tensor(0.0).to(forward_targets.device)

            self._perplexity(average_loss)

            if num_targets > 0:
                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": forward_loss / num_targets.float(),
                    "batch_weight": num_targets.float(),
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = backward_loss / num_targets.float()
            else:
                # average_loss is a zero tensor here; use it for every loss key.
                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": average_loss
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = average_loss

        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            "lm_embeddings": contextual_embeddings,
            "noncontextual_token_embeddings": embeddings,
            "mask": mask,
        })

        return return_dict
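
A minimal sketch of the target-shifting convention used above, with hypothetical token ids; index 0 acts as padding / "no target" and is excluded from `num_targets`:

import torch

token_ids = torch.tensor([[5, 7, 9, 0]])       # one padded sequence

forward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]    # predict the next token
# forward_targets -> tensor([[7, 9, 0, 0]])

backward_targets = torch.zeros_like(token_ids)
backward_targets[:, 1:] = token_ids[:, 0:-1]   # predict the previous token
# backward_targets -> tensor([[0, 5, 7, 9]])

num_targets = torch.sum((forward_targets > 0).long())  # tensor(2)
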
Example #2
    def forward(self, transactions: TextFieldTensors,
                **kwargs) -> Dict[str, torch.Tensor]:
        """
        Computes the averaged forward (and backward, if the language model is
        bidirectional) LM loss for a batch of transactions. Returns the same keys
        as the `forward` method in Example #1.
        """

        mask = get_text_field_mask(transactions)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self._text_field_embedder(transactions)

        # Either the top layer or all layers.
        contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
            embeddings, mask
        )

        return_dict = {}

        # If we have target transactions, calculate the loss.
        token_id_dict = transactions.get("tokens")
        if token_id_dict is not None:
            token_ids = token_id_dict["tokens"]
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Use token_ids to compute shifted targets (index 0 marks padding / "no target").
            forward_targets = torch.zeros_like(token_ids)
            forward_targets[:, 0:-1] = token_ids[:, 1:]

            if self._bidirectional:
                backward_targets = torch.zeros_like(token_ids)
                backward_targets[:, 1:] = token_ids[:, 0:-1]
            else:
                backward_targets = None

            # Apply dropout before computing the loss.
            contextual_embeddings_with_dropout = self._dropout(
                contextual_embeddings)

            # compute softmax loss
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout,
                embeddings,
                forward_targets,
                backward_targets,
            )

            num_targets = torch.sum((forward_targets > 0).long())
            if num_targets > 0:
                if self._bidirectional:
                    average_loss = (
                        0.5 * (forward_loss + backward_loss) / num_targets.float()
                    )
                else:
                    average_loss = forward_loss / num_targets.float()
            else:
                average_loss = torch.tensor(0.0).to(forward_targets.device)

            self._perplexity(average_loss)

            if num_targets > 0:
                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": forward_loss / num_targets.float(),
                    "batch_weight": num_targets.float(),
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = backward_loss / num_targets.float()
            else:
                # average_loss is a zero tensor here; use it for every loss key.
                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": average_loss
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = average_loss

        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            "lm_embeddings": contextual_embeddings,
            "noncontextual_token_embeddings": embeddings,
            "mask": mask,
        })

        return return_dict
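
For reference, the loss normalization in both examples reduces to simple arithmetic over the summed per-direction losses. A minimal sketch with hypothetical numbers (the summed losses and the target count are made up, not taken from either example):

import torch

forward_loss = torch.tensor(12.0)   # summed forward NLL over all targets
backward_loss = torch.tensor(10.0)  # summed backward NLL over all targets
num_targets = torch.tensor(4)       # number of non-padding prediction targets

# Bidirectional model: average the two directions, then normalize per target.
bidirectional_average = 0.5 * (forward_loss + backward_loss) / num_targets.float()
# -> tensor(2.7500)

# Unidirectional model: normalize the forward loss only.
unidirectional_average = forward_loss / num_targets.float()
# -> tensor(3.)
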