Esempio n. 1
0
    def forward(
        self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs
    ) -> torch.Tensor:
        """
        Embed each token-indexer entry of ``text_field_input`` with its matching
        token embedder and concatenate the results along the last dimension.

        # Parameters

        text_field_input : ``TextFieldTensors``
            Mapping from indexer name to the dict of tensors produced by that
            indexer. Its keys must equal the configured token-embedder keys.
        num_wrapping_dims : ``int``, optional (default = 0)
            Each embedder is wrapped in ``TimeDistributed`` this many times.
        **kwargs
            Extra arguments, forwarded to an embedder only for parameters that
            its ``forward`` signature declares.

        # Returns

        ``torch.Tensor`` - all non-``None`` embedder outputs concatenated along
        ``dim=-1``.

        # Raises

        ``ConfigurationError`` if the embedder keys and the input keys differ.
        """
        if self._token_embedders.keys() != text_field_input.keys():
            message = "Mismatched token keys: %s and %s" % (
                str(self._token_embedders.keys()),
                str(text_field_input.keys()),
            )
            raise ConfigurationError(message)

        embedded_representations = []
        for key in self._ordered_embedder_keys:
            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format(key))
            # Split the embedder's declared parameters into those supplied via
            # kwargs and those that must be filled from the indexer's tensors.
            forward_params = inspect.signature(embedder.forward).parameters
            forward_params_values = {}
            missing_tensor_args = set()
            for param in forward_params.keys():
                if param in kwargs:
                    forward_params_values[param] = kwargs[param]
                else:
                    missing_tensor_args.add(param)

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)

            tensors: Dict[str, torch.Tensor] = text_field_input[key]
            if len(tensors) == 1 and len(missing_tensor_args) == 1:
                # If there's only one tensor argument to the embedder, and we just have one tensor to
                # embed, we can just pass in that tensor, without requiring a name match.
                token_vectors = embedder(list(tensors.values())[0], **forward_params_values)
            else:
                # If there are multiple tensor arguments, we have to require matching names from the
                # TokenIndexer.
                token_vectors = embedder(**tensors, **forward_params_values)
            if token_vectors is not None:
                # An embedder may return None for rare use cases; such outputs are skipped.
                embedded_representations.append(token_vectors)
        return torch.cat(embedded_representations, dim=-1)
Esempio n. 2
0
    def forward(self, target_tokens: TextFieldTensors, **kwargs):
        """
        Policy-gradient (REINFORCE) training step, or plain decoding at eval time.

        In training mode:
        * sample sequences with ``self.sequence_generator.sampling``;
        * score them with ``self.reward_function`` against the target tokens;
        * subtract a baseline: the greedy-decode reward when ``self.self_critic``,
          otherwise the running value of ``self.metrics["average_reward"]``;
        * return the loss ``-(reward * log p(sampled sequence)).sum()``.

        When not training, the call is delegated to ``self.sequence_generator``.

        # Parameters

        target_tokens : ``TextFieldTensors``
            Reference sequences. In training mode the tensor at
            ``target_tokens[ns][ns]`` (``ns = self.target_token_namespace``) is used.
        **kwargs
            Forwarded to the sequence generator.

        # Returns

        In training mode a dict with ``"loss"`` and ``"predicted_tokens"``;
        otherwise whatever ``self.sequence_generator`` returns.
        """
        if not self.training:
            return self.sequence_generator(target_tokens=target_tokens,
                                           **kwargs)

        target_tokens = target_tokens[self.target_token_namespace][
            self.target_token_namespace]
        batch_size = target_tokens.size(0)
        sampling_output = self.sequence_generator.sampling(
            batch_size=batch_size, **kwargs)
        logits = sampling_output["logits"]
        # shape: (batch_size, sequence_length)
        predicted_tokens = sampling_output["predicted_tokens"]
        # shape: (batch_size, sequence_length, vocab_size)
        log_probs = F.log_softmax(logits, dim=2)

        # Log-probability of each sampled token.
        # shape: (batch_size, sequence_length)
        log_prob_for_predicted_tokens = log_probs.gather(
            dim=2, index=predicted_tokens.unsqueeze(2)).squeeze(2)

        # Zero out the positions excluded by get_target_mask (presumably those
        # after the end token - TODO confirm against its definition).
        mask = get_target_mask(
            predicted_tokens,
            end_index=self.sequence_generator.target_end_index)
        log_prob_for_predicted_tokens *= mask

        # shape: (batch_size, )
        log_prob_per_sequence = log_prob_for_predicted_tokens.sum(dim=1)

        # shape: (batch_size, )
        reward = self.reward_function(predicted_tokens, target_tokens)
        self.metrics["average_raw_reward"](reward)
        self.metrics["raw_reward_variance"](reward)

        if self.self_critic:
            start_tokens = target_tokens.new_full(
                size=(batch_size, ),
                fill_value=self.sequence_generator.target_start_index,
                dtype=torch.long)
            # The greedy baseline is a constant with respect to the policy
            # gradient, so decode it without building an autograd graph, and
            # always switch the generator back to training mode, even if
            # decoding raises.
            self.sequence_generator.eval()
            try:
                with torch.no_grad():
                    greedy_output = self.sequence_generator(
                        start_tokens=start_tokens, **kwargs)
            finally:
                self.sequence_generator.train()
            baseline_reward = self.reward_function(
                greedy_output["predicted_tokens"], target_tokens)
        else:
            # NOTE(review): this reads "average_reward", but the reward metrics
            # updated in this method are "average_raw_reward" and
            # "average_baselined_reward" - confirm "average_reward" is fed elsewhere.
            baseline_reward = self.metrics["average_reward"].get_metric(
                reset=False)

        reward -= baseline_reward
        self.metrics["average_baselined_reward"](reward)
        self.metrics["baselined_reward_variance"](reward)

        loss = -(reward * log_prob_per_sequence).sum()

        return {"loss": loss, "predicted_tokens": predicted_tokens}
    def forward(self,
                text_field_input: TextFieldTensors,
                num_wrapping_dims: int = 0,
                **kwargs) -> torch.Tensor:
        """
        Embed each indexer's tensors and concatenate the resulting embeddings.

        # Parameters

        text_field_input : ``TextFieldTensors``
            Mapping from indexer name to the dict of tensors that indexer produced.
        num_wrapping_dims : ``int``, optional (default = 0)
            How many times each embedder is wrapped in ``TimeDistributed``.
        **kwargs
            Passed to an embedder only for parameters its ``forward`` declares.

        # Returns

        ``torch.Tensor`` formed by concatenating every non-``None`` embedder
        output along the last dimension. ``EmptyEmbedder`` instances are skipped.

        # Raises

        ``ConfigurationError`` when the input keys differ from the embedder keys,
        unless the only difference is extra embedder keys whose embedders are
        all ``EmptyEmbedder`` instances.
        """
        if sorted(self._token_embedders.keys()) != sorted(text_field_input.keys()):
            embedder_keys = set(self._token_embedders.keys())
            input_keys = set(text_field_input.keys())
            unused_keys = embedder_keys - input_keys
            # Embedders that receive no input are tolerated only when every one
            # of them is an EmptyEmbedder (those are skipped in the loop below).
            tolerable = embedder_keys > input_keys and all(
                isinstance(embedder, EmptyEmbedder)
                for name, embedder in self._token_embedders.items()
                if name in unused_keys)
            if not tolerable:
                raise ConfigurationError("Mismatched token keys: %s and %s" % (
                    str(self._token_embedders.keys()),
                    str(text_field_input.keys()),
                ))

        pieces = []
        for key in self._ordered_embedder_keys:
            # Look the embedder up through getattr so submodule handling keeps
            # working with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format(key))
            if isinstance(embedder, EmptyEmbedder):
                continue

            parameter_names = list(inspect.signature(embedder.forward).parameters)
            extra_args = {name: kwargs[name] for name in parameter_names if name in kwargs}
            tensor_arg_names = {name for name in parameter_names if name not in kwargs}

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)

            tensors: Dict[str, torch.Tensor] = text_field_input[key]
            if len(tensors) != 1 or len(tensor_arg_names) != 1:
                # Several tensor arguments: the TokenIndexer's tensor names must
                # match the embedder's parameter names.
                token_vectors = embedder(**tensors, **extra_args)
            else:
                # One tensor for one tensor parameter: pass it positionally, so
                # no name match is needed.
                token_vectors = embedder(next(iter(tensors.values())), **extra_args)

            # An embedder may return None in rare cases; such outputs are dropped.
            if token_vectors is not None:
                pieces.append(token_vectors)
        return torch.cat(pieces, dim=-1)
Esempio n. 4
0
    def forward(self,
                text_field_input: TextFieldTensors,
                augment: int,
                difficulty_step: int,
                num_wrapping_dims: int = 0,
                **kwargs) -> torch.Tensor:
        """
        Embed each indexer's tensors while collecting the embedders' masked-LM loss.

        Every embedder is called as ``embedder(augment, difficulty_step, ...)``
        and its result is unpacked as ``(masked_lm_loss, token_vectors)``.

        # Parameters

        text_field_input : ``TextFieldTensors``
            Mapping from indexer name to that indexer's tensors. Its keys must
            equal the configured token-embedder keys.
        augment, difficulty_step : ``int``
            Passed positionally as the first two arguments of every embedder.
        num_wrapping_dims : ``int``, optional (default = 0)
            How many times each embedder is wrapped in ``TimeDistributed``.
        **kwargs
            Passed to an embedder only for parameters its ``forward`` declares.

        # Returns

        A tuple ``(masked_lm_loss, embeddings)``. Note that this is a tuple,
        despite the ``torch.Tensor`` annotation.
        * ``masked_lm_loss`` is the loss from the last embedder in
          ``self._ordered_embedder_keys``. NOTE(review): losses from earlier
          embedders are discarded - confirm only one embedder is expected.
        * ``embeddings`` concatenates every non-``None`` embedder output along
          the last dimension.

        # Raises

        ``ConfigurationError`` if the embedder keys and the input keys differ.
        """
        if self._token_embedders.keys() != text_field_input.keys():
            raise ConfigurationError("Mismatched token keys: %s and %s" % (
                str(self._token_embedders.keys()),
                str(text_field_input.keys()),
            ))

        collected = []
        for key in self._ordered_embedder_keys:
            # Look the embedder up through getattr so submodule handling keeps
            # working with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format(key))
            declared = inspect.signature(embedder.forward).parameters
            extra_args = {}
            unfilled = set()
            for name in declared:
                if name in kwargs:
                    extra_args[name] = kwargs[name]
                else:
                    unfilled.add(name)

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)

            tensors: Dict[str, torch.Tensor] = text_field_input[key]
            # NOTE(review): ``unfilled`` also counts the parameters that receive
            # augment/difficulty_step, so this positional shortcut only fires
            # for embedders whose signature has a single non-kwarg parameter
            # (e.g. ``*args``) - confirm this is intended.
            if len(tensors) == 1 and len(unfilled) == 1:
                positional = (augment, difficulty_step, list(tensors.values())[0])
                masked_lm_loss, token_vectors = embedder(*positional, **extra_args)
            else:
                # Tensor names from the TokenIndexer must match the embedder's
                # parameter names.
                masked_lm_loss, token_vectors = embedder(
                    augment, difficulty_step, **tensors, **extra_args)

            # An embedder may return None vectors in rare cases; skip those.
            if token_vectors is not None:
                collected.append(token_vectors)
        return masked_lm_loss, torch.cat(collected, dim=-1)
Esempio n. 5
0
def elmo_input_reshape(inputs: TextFieldTensors, batch_size: int,
                       number_targets: int,
                       batch_size_num_targets: int) -> TextFieldTensors:
    '''
    Merge the batch and target dimensions of every tensor in ``inputs`` when an
    ``elmo`` or ``token_characters`` indexer is present.

    NOTE: This does not work for the huggingface transformers. Their token
    indexers produce additional keys besides the token ids (such as mask ids
    and segment ids), and those also need handling, which has not been
    implemented yet. A more appropriate alternative would be to use
    `target_sequences`, as in the `InContext` model, to generate
    contextualised targets from the context rather than using the target
    words as-is without context.

    :param inputs: The token indexer dictionary where the keys state the token
                   indexer and the values are the Tensors that are of shape
                   (Batch Size, Number Targets, Sequence Length)
    :param batch_size: The Batch Size (not used by the reshape itself)
    :param number_targets: The max number of targets in the batch (not used
                           by the reshape itself)
    :param batch_size_num_targets: Batch Size * number of targets
    :returns: If the inputs contain an `elmo` or `token_characters` key, a new
              dictionary in which every key's tensors have their first two
              dimensions merged into one of size `batch_size_num_targets`,
              e.g. (Batch Size * Number Targets, Sequence Length), so that they
              can be processed by the ELMo or character embedder/encoder.
              Otherwise `inputs` is returned unchanged.
    '''
    if 'elmo' in inputs or 'token_characters' in inputs:
        temp_inputs: TextFieldTensors = defaultdict(dict)
        for key, inner_key_value in inputs.items():
            for inner_key, value in inner_key_value.items():
                # `reshape` rather than `view`: the result is identical for
                # contiguous tensors, but reshape also accepts non-contiguous
                # ones, where `view` raises.
                temp_value = value.reshape(batch_size_num_targets,
                                           *value.shape[2:])
                temp_inputs[key][inner_key] = temp_value
        return dict(temp_inputs)
    else:
        return inputs
Esempio n. 6
0
    def mask_tokens(self, inputs: TextFieldTensors) -> Tuple[TextFieldTensors, TextFieldTensors]:
        """
        Randomly corrupt token ids for masked-language-model training.

        For every tensor in every text field:
        * each position is independently replaced by ``self.mask_idx`` with
          probability ``self.mask_probability``;
        * each position that was *not* masked is independently replaced by a
          random id drawn from ``[1, self.vocab_size)`` with probability
          ``self.replace_probability``.

        The tensors in ``inputs`` are not modified; corrupted copies are returned.

        # Returns

        ``(masked_inputs, masked_targets)``, both nested like ``inputs``. The
        first holds the corrupted ids; the second holds copies of the original ids.
        """
        masked_inputs = dict()
        masked_targets = dict()
        for text_field_name, text_field in inputs.items():
            masked_inputs[text_field_name] = dict()
            masked_targets[text_field_name] = dict()
            for key, tokens in text_field.items():
                labels = tokens.clone()
                # Corrupt a copy so the caller's batch tensors are not mutated in place.
                corrupted = tokens.clone()

                indices_masked = torch.bernoulli(
                    torch.full(labels.shape, self.mask_probability, device=tokens.device)
                ).bool()
                corrupted[indices_masked] = self.mask_idx

                # Random replacement only applies to positions that were not masked above.
                indices_random = (
                    torch.bernoulli(torch.full(labels.shape, self.replace_probability, device=tokens.device)).bool()
                    & ~indices_masked
                )
                random_tokens = torch.randint(
                    low=1, high=self.vocab_size, size=labels.shape, dtype=torch.long, device=tokens.device,
                )
                corrupted[indices_random] = random_tokens[indices_random]

                masked_inputs[text_field_name][key] = corrupted
                masked_targets[text_field_name][key] = labels
        return masked_inputs, masked_targets
    def get_step_state(self,
                       inputs: TextFieldTensors) -> Dict[str, torch.Tensor]:
        """
        Build the ``BeamSearch`` ``state`` dict for ``NextTokenLm`` from ``inputs``.

        The default implementation supports exactly one ``TokenEmbedder``
        sub-dictionary. It "flattens" ``inputs`` by returning a shallow copy of
        that sub-dictionary. Override this method when ``inputs`` holds more
        than one sub-dictionary.
        """
        assert len(inputs) == 1, (
            "'get_step_state()' assumes a single token embedder by default, "
            "you'll need to override this method to handle more than one")

        sole_key = list(inputs)[0]
        # Return a copy rather than inputs[sole_key] itself, so callers can add
        # or remove state entries without touching the original inputs.
        return dict(inputs[sole_key])
    def forward(
        self,  # type: ignore
        tokens: TextFieldTensors,
        label: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Classify each instance with the biattentive classification network.

        # Parameters

        tokens : TextFieldTensors, required
            The output of `TextField.as_array()`. An optional `"elmo"` entry is
            fed to `self._elmo` instead of the text field embedder. The entry is
            temporarily removed and always put back before returning.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.

        # Returns

        An output dictionary consisting of:
        logits : torch.FloatTensor
            Unnormalised class scores, the same shape as `class_probabilities`.
        class_probabilities : torch.FloatTensor
            A tensor of shape `(batch_size, num_classes)` representing a
            distribution over the label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when `label` is given;
            `self.metrics` are updated in that case too.

        # Raises

        ConfigurationError
            If the model was built to use ELMo but `tokens` has no `"elmo"` entry.
        """
        text_mask = util.get_text_field_mask(tokens)
        # Pop elmo tokens, since elmo embedder should not be present.
        elmo_tokens = tokens.pop("elmo", None)
        try:
            if tokens:
                embedded_text = self._text_field_embedder(tokens)
            else:
                # only using "elmo" for input
                embedded_text = None
        finally:
            # Add the "elmo" key back to "tokens" (even if embedding raised), since the
            # tests and the subsequent training epochs rely on it not being modified
            # during forward().
            if elmo_tokens is not None:
                tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self._elmo:
            if elmo_tokens is not None:
                elmo_representations = self._elmo(
                    elmo_tokens["tokens"])["elmo_representations"]
                # Pop from the end is more performant with list. When both are
                # used, the integrator-output representation is the last element
                # and the input representation the one before it.
                if self._use_integrator_output_elmo:
                    integrator_output_elmo = elmo_representations.pop()
                if self._use_input_elmo:
                    input_elmo = elmo_representations.pop()
                assert not elmo_representations
            else:
                raise ConfigurationError(
                    "Model was built to use Elmo, but input text is not tokenized for Elmo."
                )

        # NOTE(review): input_elmo / integrator_output_elmo are only bound when
        # self._elmo is set; presumably the constructor enforces that.
        if self._use_input_elmo:
            if embedded_text is not None:
                embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)
            else:
                embedded_text = input_elmo

        dropped_embedded_text = self._embedding_dropout(embedded_text)
        pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
        encoded_tokens = self._encoder(pre_encoded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(
            encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.masked_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([
            encoded_tokens, encoded_tokens - encoded_text,
            encoded_tokens * encoded_text
        ], 2)
        integrated_encodings = self._integrator(integrator_input, text_mask)

        # Concatenate ELMo representations to integrated_encodings if specified
        if self._use_integrator_output_elmo:
            integrated_encodings = torch.cat(
                [integrated_encodings, integrator_output_elmo], dim=-1)

        # Simple Pooling layers. Masked positions are filled with large
        # negative / positive values so they do not win the max / min.
        max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, text_mask.unsqueeze(2), -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, text_mask.unsqueeze(2), +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        # NOTE(review): the numerator sums over every timestep, including padded
        # ones; this assumes the integrator outputs zeros at padded positions.
        mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(
            text_mask, 1, keepdim=True)

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings,
                                                self_weights)

        pooled_representations = torch.cat(
            [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(
            pooled_representations)

        logits = self._output_layer(pooled_representations_dropped)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
Esempio n. 9
0
    def forward(
            self, source: TextFieldTensors
    ) -> Dict[str, torch.Tensor]:  # type: ignore
        """
        Computes the averaged forward (and backward, if language model is bidirectional)
        LM loss from the batch.

        # Parameters

        source : `TextFieldTensors`, required.
            The output of `Batch.as_tensor_dict()` for a batch of sentences. By convention,
            it's required to have at least a `"tokens"` entry that's the output of a
            `SingleIdTokenIndexer`, which is used to compute the language model targets.

        # Returns

        Dict with keys:

        `'lm_embeddings'` : `Union[torch.Tensor, List[torch.Tensor]]`
            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
            list of all layers. No dropout applied.
        `'noncontextual_token_embeddings'` : `torch.Tensor`
            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
            representations
        `'mask'` : `torch.BoolTensor`
            (batch_size, timesteps) mask for the embeddings

        and, only when `source` has a `"tokens"` entry:

        `'loss'` : `torch.Tensor`
            forward negative log likelihood, or the average of forward/backward
            if language model is bidirectional (a zero tensor if there are no targets)
        `'forward_loss'` : `torch.Tensor`
            forward direction negative log likelihood
        `'backward_loss'` : `torch.Tensor`
            backward direction negative log likelihood. The key is omitted (not set to
            `None`) when `_compute_loss` returns no backward loss, presumably whenever
            the language model is not bidirectional.
        `'batch_weight'` : `torch.Tensor`
            number of (non-zero) targets as a float; only present when that number is > 0.
        """

        mask = get_text_field_mask(source)

        # shape (batch_size, timesteps, embedding_size)
        embeddings = self._text_field_embedder(source)

        # Either the top layer or all layers.
        contextual_embeddings: Union[
            torch.Tensor,
            List[torch.Tensor]] = self._contextualizer(embeddings, mask)

        return_dict = {}

        # If we have target tokens, calculate the loss.
        token_id_dict = source.get("tokens")
        if token_id_dict is not None:
            token_ids = token_id_dict["tokens"]
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Forward targets: ids shifted one step left (the next token); the last
            # position keeps the fill value 0.
            forward_targets = torch.zeros_like(token_ids)
            forward_targets[:, 0:-1] = token_ids[:, 1:]

            if self._bidirectional:
                # Backward targets: ids shifted one step right (the previous token).
                backward_targets = torch.zeros_like(token_ids)
                backward_targets[:, 1:] = token_ids[:, 0:-1]
            else:
                backward_targets = None

            # add dropout
            contextual_embeddings_with_dropout = self._dropout(
                contextual_embeddings)

            # compute softmax loss
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, embeddings,
                forward_targets, backward_targets)

            # Targets equal to 0 (the fill value above, presumably also the padding id)
            # are not counted.
            num_targets = torch.sum((forward_targets > 0).long())
            if num_targets > 0:
                num_targets_float = num_targets.float()
                if self._bidirectional:
                    average_loss = 0.5 * (forward_loss +
                                          backward_loss) / num_targets_float
                else:
                    average_loss = forward_loss / num_targets_float
                self._perplexity(average_loss)

                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": forward_loss / num_targets_float,
                    "batch_weight": num_targets_float,
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = backward_loss / num_targets_float
            else:
                # No targets: report a zero loss for every loss key.
                average_loss = torch.tensor(0.0, device=forward_targets.device)
                self._perplexity(average_loss)

                return_dict.update({
                    "loss": average_loss,
                    "forward_loss": average_loss
                })
                if backward_loss is not None:
                    return_dict["backward_loss"] = average_loss

        return_dict.update({
            # Note: These embeddings do not have dropout applied.
            "lm_embeddings": contextual_embeddings,
            "noncontextual_token_embeddings": embeddings,
            "mask": mask,
        })

        return return_dict
Esempio n. 10
0
    def forward(
            self,
            context_ids: TextFieldTensors,
            query_ids: TextFieldTensors,
            context_lens: torch.Tensor,
            query_lens: torch.Tensor,
            mask_label: Optional[torch.Tensor] = None,
            cls_label: Optional[torch.Tensor] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Encode context and query jointly, then run the optional cls / mask heads.

        For every indexer and inner key, the context and query tensors are
        concatenated along dim 1 and embedded together. The embedder output is
        then split back into a context part and a query part at
        ``max(context_lens)``.

        # Parameters

        context_ids, query_ids : ``TextFieldTensors``
            ``query_ids`` is indexed with the indexers and inner keys of
            ``context_ids``, so both must share them.
        context_lens, query_lens : ``torch.Tensor``
            Per-instance lengths. Their maxima give the split point and the
            width of the query mask.
        mask_label, cls_label : ``torch.Tensor``, optional
            Labels passed to ``self._calc_loss``. The loss is only computed when
            ``cls_label`` is given.
        metadata : unused here.

        # Returns

        A dict that may contain:
        * ``"cls_logits"`` - present if ``self._cls_task``;
        * ``"mask_logits"`` - present if ``self._mask_task``;
        * ``"loss"`` - present if ``cls_label`` is given.
        """
        # Concatenate context and query so the encoder sees the whole dialogue.
        # Get the indexers first.
        indexers = context_ids.keys()
        dialogue_ids = {}

        # Maximum context and query lengths in the batch.
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, query_len]
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # get the outputs of the dialogue
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # Split the outputs back into context and query parts.
        # NOTE(review): this split assumes the padded width of the context
        # tensors equals max(context_lens); confirm against the dataset reader.
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        output_dict = {}
        # --------- cls task: decide whether the query needs rewriting ---------
        if self._cls_task:
            # First-position ("cls") representation, [B, embed_size]
            cls_embed = context_last_layer[:, 0]
            # Classify with a linear layer, [B, 2]
            cls_logits = self._cls_linear(cls_embed)
            output_dict["cls_logits"] = cls_logits
        else:
            cls_logits = None

        # --------- mask task: find the positions in the query to fill ---------
        if self._mask_task:
            # Per-position linear layer, [B, _len, 2]
            mask_logits = self._mask_linear(query_last_layer)
            output_dict["mask_logits"] = mask_logits
        else:
            mask_logits = None

        if cls_label is not None:
            output_dict["loss"] = self._calc_loss(cls_label, mask_label,
                                                  cls_logits, mask_logits,
                                                  query_mask)

        return output_dict
Esempio n. 11
0
    def forward(self, transactions: TextFieldTensors,
                **kwargs) -> Dict[str, torch.Tensor]:
        """
        Run the transaction language model over a batch and compute its loss.

        Always returns ``"lm_embeddings"`` (contextual, without dropout),
        ``"noncontextual_token_embeddings"`` and ``"mask"``.

        If ``transactions`` has a ``"tokens"`` entry, next-token (and, when
        bidirectional, previous-token) targets are built from it, and the dict
        also gets ``"loss"`` and ``"forward_loss"``. It additionally gets
        ``"backward_loss"`` when a backward loss exists, and ``"batch_weight"``
        when there is at least one non-zero target. ``**kwargs`` is unused.
        """
        mask = get_text_field_mask(transactions)
        # shape (batch_size, timesteps, embedding_size)
        embeddings = self._text_field_embedder(transactions)
        # Either the top layer or all layers.
        contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = \
            self._contextualizer(embeddings, mask)

        outputs = {}

        id_dict = transactions.get("tokens")
        if id_dict is not None:
            ids = id_dict["tokens"]
            assert isinstance(contextual_embeddings, torch.Tensor)

            # Next-token targets; the final position stays 0.
            fwd_targets = torch.zeros_like(ids)
            fwd_targets[:, 0:-1] = ids[:, 1:]

            bwd_targets = None
            if self._bidirectional:
                # Previous-token targets; the first position stays 0.
                bwd_targets = torch.zeros_like(ids)
                bwd_targets[:, 1:] = ids[:, 0:-1]

            fwd_loss, bwd_loss = self._compute_loss(
                self._dropout(contextual_embeddings),
                embeddings,
                fwd_targets,
                bwd_targets,
            )

            n_targets = torch.sum((fwd_targets > 0).long())
            if not n_targets > 0:
                # Nothing to predict: report a zero loss everywhere.
                zero = torch.tensor(0.0).to(fwd_targets.device)
                self._perplexity(zero)
                outputs["loss"] = zero
                outputs["forward_loss"] = zero
                if bwd_loss is not None:
                    outputs["backward_loss"] = zero
            else:
                denom = n_targets.float()
                if self._bidirectional:
                    avg = 0.5 * (fwd_loss + bwd_loss) / denom
                else:
                    avg = fwd_loss / denom
                self._perplexity(avg)
                outputs["loss"] = avg
                outputs["forward_loss"] = fwd_loss / denom
                outputs["batch_weight"] = denom
                if bwd_loss is not None:
                    outputs["backward_loss"] = bwd_loss / denom

        # These embeddings are returned without dropout.
        outputs["lm_embeddings"] = contextual_embeddings
        outputs["noncontextual_token_embeddings"] = embeddings
        outputs["mask"] = mask
        return outputs
Esempio n. 12
0
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        mask_positions: torch.LongTensor,
        target_ids: TextFieldTensors = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : ``TextFieldTensors``
            The output of ``TextField.as_tensor()`` for a batch of sentences.
        mask_positions : ``torch.LongTensor``
            The positions in ``tokens`` that correspond to [MASK] tokens that we should try to fill
            in. After ``squeeze(-1)`` the shape must be (batch_size, num_masks).
        target_ids : ``TextFieldTensors``
            The token ids that correspond to the mask positions we're trying to fill.
            It is the output of a ``TextField``, purely for convenience, so we can handle wordpiece
            tokenizers and such without having to do crazy things in the dataset reader. We assume
            that there is exactly one entry in the dictionary (otherwise ``["bert"]["token_ids"]``
            is read), and that it has a shape identical to ``mask_positions`` - one target token
            per mask position.

        # Returns

        A dict with:
        * ``"probabilities"`` and ``"top_indices"`` - the top-k (k = min(5, vocab size))
          softmax probabilities and vocabulary indices for every mask position;
        * ``"token_ids"`` - the token ids extracted from ``tokens``;
        * ``"loss"`` - cross-entropy over the mask positions, only when ``target_ids`` is given.

        # Raises

        ``ValueError`` if the targets' shape differs from that of ``mask_positions``.
        """

        targets = None
        if target_ids is not None:
            # A bit of a hack to get the right targets out of the TextField output...
            if len(target_ids) != 1:
                targets = target_ids["bert"]["token_ids"]
            else:
                targets = list(target_ids.values())[0]["tokens"]
        mask_positions = mask_positions.squeeze(-1)
        batch_size, num_masks = mask_positions.size()
        if targets is not None and targets.size() != mask_positions.size():
            raise ValueError(
                f"Number of targets ({targets.size()}) and number of masks "
                f"({mask_positions.size()}) are not equal")

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            # The mask is derived from the TextFieldTensors (the nested indexer
            # dict), not from the embedded tensor.
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
        else:
            contextual_embeddings = embeddings

        # Does advanced indexing to get the embeddings of just the mask positions, which is what
        # we're trying to predict. The row index lives on the same device as mask_positions.
        batch_index = torch.arange(
            0, batch_size, device=mask_positions.device).long().unsqueeze(1)
        mask_embeddings = contextual_embeddings[batch_index, mask_positions]

        target_logits = self._language_model_head(
            self._dropout(mask_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size,
                5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)

        if targets is not None:
            target_logits = target_logits.view(batch_size * num_masks,
                                               vocab_size)
            targets = targets.view(batch_size * num_masks)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        return output_dict
Esempio n. 13
0
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                extend_context_ids: torch.Tensor,
                extend_query_ids: torch.Tensor,
                context_len: torch.Tensor,
                query_len: torch.Tensor,
                oovs_len: torch.Tensor,
                rewrite_input_ids: Optional[TextFieldTensors] = None,
                rewrite_target_ids: Optional[TextFieldTensors] = None,
                extend_rewrite_ids: Optional[torch.Tensor] = None,
                rewrite_len: Optional[torch.Tensor] = None,
                metadata: Optional[List[Dict[str, Any]]] = None):
        """
        Encode the dialogue (context + query), then train on or decode the rewrite.

        The generic ``*_ids`` arguments use AllenNLP's default
        ``TextFieldTensors`` layout (``{indexer_name: {key: tensor}}``); this
        method reads the ``"token_ids"`` and ``"mask"`` entries stored under
        ``self._index_name``. ``extend_context_ids``, ``extend_query_ids`` and
        ``extend_rewrite_ids`` are prepared during data preprocessing and are
        added element-wise to the corresponding token ids (presumably to map
        OOV tokens into an extended vocabulary -- confirm against the dataset
        reader).

        ``context_len`` and ``query_len`` are not used by this method (the
        masks come from the text field tensors). ``rewrite_len`` is forwarded
        to ``self._forward_step`` in training; ``oovs_len`` is forwarded to
        both the training step and inference.

        Returns
        -------
        Dict
            Always contains ``"metadata"``. In training mode it is updated
            with the output of ``self._forward_step``. Otherwise it holds
            ``"hypothesis"`` (then post-processed by
            ``self.get_rewrite_string``) and a placeholder ``"loss"`` of
            ``torch.tensor(0)``.

        Raises
        ------
        ValueError
            In training mode, if ``rewrite_input_ids``, ``rewrite_target_ids``
            or ``extend_rewrite_ids`` is ``None``.
        """
        if self.training:
            # Fail fast with a clear message instead of a TypeError from
            # subscripting / converting ``None`` after the encoder has run.
            missing = [
                name for name, value in (
                    ("rewrite_input_ids", rewrite_input_ids),
                    ("rewrite_target_ids", rewrite_target_ids),
                    ("extend_rewrite_ids", extend_rewrite_ids),
                ) if value is None
            ]
            if missing:
                raise ValueError("Training mode requires non-None " +
                                 ", ".join(missing))

        # Token ids and masks of the context and the query.
        context_token_ids = context_ids[self._index_name]["token_ids"]
        query_token_ids = query_ids[self._index_name]["token_ids"]
        context_mask = context_ids[self._index_name]["mask"]
        query_mask = query_ids[self._index_name]["mask"]

        # Extended ids of the context and query (base ids + preprocessed offsets).
        extend_context_ids = context_token_ids + extend_context_ids.to(
            dtype=torch.long)
        extend_query_ids = query_token_ids + extend_query_ids.to(
            dtype=torch.long)

        # ---------- encoder output ----------
        # Context and query are encoded jointly, so every tensor of every
        # indexer is concatenated along the length dimension (dim=1).
        dialogue_ids = {
            indexer: {
                key: torch.cat([context_tensor, query_ids[indexer][key]],
                               dim=1)
                for key, context_tensor in context_fields.items()
            }
            for indexer, context_fields in context_ids.items()
        }

        dialogue_output = self._text_field_embedder(dialogue_ids)
        context_output, query_output, dec_init_state = self._run_encoder(
            dialogue_output, context_mask, query_mask)

        output_dict = {"metadata": metadata}
        if self.training:
            rewrite_input_token_ids = rewrite_input_ids[
                self._index_name]["token_ids"]
            rewrite_input_mask = rewrite_input_ids[self._index_name]["mask"]
            rewrite_target_token_ids = (
                rewrite_target_ids[self._index_name]["token_ids"] +
                extend_rewrite_ids.to(dtype=torch.long))

            # [B, rewrite_len, encoder_output_dim]
            rewrite_embed = self._get_embeddings(rewrite_input_token_ids)
            new_output_dict = self._forward_step(
                context_output, query_output, context_mask, query_mask,
                rewrite_embed, rewrite_target_token_ids, rewrite_len,
                rewrite_input_mask, extend_context_ids, extend_query_ids,
                oovs_len, dec_init_state)
            output_dict.update(new_output_dict)
        else:
            batch_hyps = self._run_inference(context_output,
                                             query_output,
                                             context_mask,
                                             query_mask,
                                             extend_context_ids,
                                             extend_query_ids,
                                             oovs_len,
                                             dec_init_state=dec_init_state)
            # Decoded hypotheses for each instance in the batch.
            output_dict['hypothesis'] = batch_hyps
            output_dict = self.get_rewrite_string(output_dict)
            # Placeholder so callers that always read "loss" keep working.
            output_dict["loss"] = torch.tensor(0)
        return output_dict
Esempio n. 14
0
    def forward(self,
                left_contexts: TextFieldTensors,
                right_contexts: TextFieldTensors,
                targets: TextFieldTensors,
                target_sentiments: torch.LongTensor = None,
                metadata: torch.LongTensor = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        '''
        Classify the sentiment of every target from its left and right contexts.

        The text and targets are dictionaries because, as text fields, they
        can be represented in many different ways, e.g. just words, or words
        and chars; the dictionary holds one entry per representation, e.g.
        {'words': words_tensor_ids, 'chars': char_tensor_ids}.

        :param left_contexts: Text to the left of each target.
        :param right_contexts: Text to the right of each target.
        :param targets: Tokens of each target.
        :param target_sentiments: Optional gold labels aligned with the
            (batch size, number targets) target mask. When given, the loss is
            computed and ``self.metrics`` / ``self.f1_metrics`` are updated.
        :param metadata: Despite the annotation, this is iterated as a
            sequence of per-sample dicts with the keys 'text words', 'text',
            'targets' and 'target words'.
        :returns: Dict with 'class_probabilities' and 'targets_mask', plus
            'loss' when ``target_sentiments`` is given and 'words', 'text',
            'targets', 'target words' when ``metadata`` is given.
        '''
        # Token-level target mask. num_wrapping_dims=1 is required if the
        # input is of shape greater than 3 dim e.g. character input where it
        # is (batch size, number targets, token length, char length).
        # Computed once and reused by the target encoder below.
        target_token_mask = util.get_text_field_mask(targets,
                                                     num_wrapping_dims=1)
        # A target slot is real (1) if it contains at least one token.
        targets_mask = (target_token_mask.sum(dim=-1) >= 1).type(torch.int64)
        batch_size, number_targets = targets_mask.shape
        batch_size_num_targets = batch_size * number_targets

        temp_left_contexts = elmo_input_reshape(left_contexts, batch_size,
                                                number_targets,
                                                batch_size_num_targets)
        left_embedded_text = self.context_field_embedder(temp_left_contexts)
        left_embedded_text = elmo_input_reverse(left_embedded_text,
                                                left_contexts, batch_size,
                                                number_targets,
                                                batch_size_num_targets)
        left_embedded_text = self._time_variational_dropout(left_embedded_text)
        left_text_mask = util.get_text_field_mask(left_contexts,
                                                  num_wrapping_dims=1)

        temp_right_contexts = elmo_input_reshape(right_contexts, batch_size,
                                                 number_targets,
                                                 batch_size_num_targets)
        right_embedded_text = self.context_field_embedder(temp_right_contexts)
        right_embedded_text = elmo_input_reverse(right_embedded_text,
                                                 right_contexts, batch_size,
                                                 number_targets,
                                                 batch_size_num_targets)
        right_embedded_text = self._time_variational_dropout(
            right_embedded_text)
        right_text_mask = util.get_text_field_mask(right_contexts,
                                                   num_wrapping_dims=1)
        if self.target_encoder:
            temp_target = elmo_input_reshape(targets, batch_size,
                                             number_targets,
                                             batch_size_num_targets)
            if self.target_field_embedder:
                embedded_target = self.target_field_embedder(temp_target)
            else:
                embedded_target = self.context_field_embedder(temp_target)
            embedded_target = elmo_input_reverse(embedded_target, targets,
                                                 batch_size, number_targets,
                                                 batch_size_num_targets)
            embedded_target = self._time_variational_dropout(embedded_target)

            target_encoded_text = self.target_encoder(embedded_target,
                                                      target_token_mask)
            target_encoded_text = self._naive_dropout(target_encoded_text)
            # Encoded target to be of dimension (batch, Number of Targets, words, dim)
            # currently (batch, Number of Targets, dim)
            target_encoded_text = target_encoded_text.unsqueeze(2)

            # Repeat the encoded target for every (padded) word position in
            # the left and right contexts.
            left_num_padded = left_embedded_text.shape[2]
            right_num_padded = right_embedded_text.shape[2]

            left_targets = target_encoded_text.repeat(
                (1, 1, left_num_padded, 1))
            right_targets = target_encoded_text.repeat(
                (1, 1, right_num_padded, 1))
            # Append the target representation to each context word.
            left_embedded_text = torch.cat((left_embedded_text, left_targets),
                                           -1)
            right_embedded_text = torch.cat(
                (right_embedded_text, right_targets), -1)

        left_encoded_text = self.left_text_encoder(left_embedded_text,
                                                   left_text_mask)
        left_encoded_text = self._naive_dropout(left_encoded_text)

        right_encoded_text = self.right_text_encoder(right_embedded_text,
                                                     right_text_mask)
        right_encoded_text = self._naive_dropout(right_encoded_text)

        encoded_left_right = torch.cat([left_encoded_text, right_encoded_text],
                                       dim=-1)

        if self.inter_target_encoding is not None:
            encoded_left_right = self.inter_target_encoding(
                encoded_left_right, targets_mask)
            encoded_left_right = self._variational_dropout(encoded_left_right)

        if self.feedforward:
            encoded_left_right = self.feedforward(encoded_left_right)
        logits = self.label_projection(encoded_left_right)

        masked_class_probabilities = util.masked_softmax(
            logits, targets_mask.unsqueeze(-1))

        output_dict = {
            "class_probabilities": masked_class_probabilities,
            "targets_mask": targets_mask
        }
        # Bool version of the mask for the loss and metrics.
        bool_targets_mask = targets_mask == 1

        if target_sentiments is not None:
            # gets the loss per target instance due to the average=`token`
            if self.loss_weights is not None:
                loss = util.sequence_cross_entropy_with_logits(
                    logits,
                    target_sentiments,
                    bool_targets_mask,
                    average='token',
                    alpha=self.loss_weights)
            else:
                loss = util.sequence_cross_entropy_with_logits(
                    logits, target_sentiments, bool_targets_mask,
                    average='token')
            for metrics in [self.metrics, self.f1_metrics]:
                for metric in metrics.values():
                    metric(logits, target_sentiments, bool_targets_mask)
            output_dict["loss"] = loss

        if metadata is not None:
            # Distinct local names so the ``targets`` parameter is not shadowed.
            words = []
            texts = []
            sample_targets = []
            target_words = []
            for sample in metadata:
                words.append(sample['text words'])
                texts.append(sample['text'])
                sample_targets.append(sample['targets'])
                target_words.append(sample['target words'])
            output_dict["words"] = words
            output_dict["text"] = texts
            output_dict["targets"] = sample_targets
            output_dict["target words"] = target_words

        return output_dict
Esempio n. 15
0
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                context_lens: torch.Tensor,
                query_lens: torch.Tensor,
                mask_label: Optional[torch.Tensor] = None,
                start_label: Optional[torch.Tensor] = None,
                end_label: Optional[torch.Tensor] = None,
                metadata: Optional[List[Dict[str, Any]]] = None):
        """
        Predict, for selected query positions, a (start, end) span in the context.

        For each masked position in the query the model predicts the content
        to insert after that token, expressed as the start and end positions
        of a span in the context. The context and query are concatenated
        along the length dimension, encoded jointly, and split back at
        ``max(context_lens)``. The split assumes the context tensors are
        padded to exactly that length; TODO confirm against the dataset reader.

        Query positions are selected by a boolean mask:
        * inference, or ``self._neg_sample_ratio <= 0``: every non-padding
          query position;
        * training with ``self._neg_sample_ratio > 0``: positions set in
          ``mask_label`` plus randomly sampled extra positions, restricted to
          non-padding positions.

        Returns a dict with ``span_start_logits``, ``span_start_probs``,
        ``span_end_logits``, ``span_end_probs``, ``best_spans`` (``[B_mask,
        2]``: start, end) and ``best_span_scores``; plus ``loss`` when
        ``start_label`` is given and ``rewrite_results`` when ``metadata`` is
        given.

        Raises ValueError if negative sampling is active but ``mask_label``
        is None.
        """
        # Padded lengths of the context and query.
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)

        # Concatenate context and query along the length dimension, for every
        # indexer and every tensor key it produces.
        dialogue_ids = {
            indexer: {
                key: torch.cat([context_tensor, query_ids[indexer][key]],
                               dim=1)
                for key, context_tensor in context_fields.items()
            }
            for indexer, context_fields in context_ids.items()
        }

        # Encode the whole dialogue.
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # Split back into context and query outputs, [B, _len, embed_size].
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        # ------- span prediction -------
        # Pair every query position with the whole context:
        # [B, query_len, context_len, embed_size]
        context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
            -1, query_len, -1, -1).contiguous()
        # [B, query_len, context_len]
        context_expand_mask = context_mask.unsqueeze(dim=1).expand(
            -1, query_len, -1).contiguous()

        span_embed_size = context_last_layer.size(-1)

        if self.training and self._neg_sample_ratio > 0.0:
            if mask_label is None:
                raise ValueError(
                    "`mask_label` is required for negative sampling during "
                    "training")
            # Flatten to [B*query_len]. Clone so that marking negative samples
            # below does not mutate the caller's ``mask_label`` (``view``
            # would share storage with it).
            sample_mask_label = mask_label.reshape(-1).clone()
            # Number of negatives: ratio * number of positives, at least 10.
            mask_length = sample_mask_label.size(0)
            mask_sum = int(
                torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
            mask_sum = max(10, mask_sum)
            # Uniformly sampled flat positions (may include duplicates or
            # positives).
            neg_indexes = torch.randint(low=0,
                                        high=mask_length,
                                        size=(mask_sum, ))
            # Never sample more positions than exist.
            neg_indexes = neg_indexes[:mask_length]
            # Mark the sampled negatives as selected.
            sample_mask_label[neg_indexes] = 1
            # [B, query_len]
            use_mask_label = sample_mask_label.view(
                -1, query_len).to(dtype=torch.bool)
            # Drop query padding positions, [B, query_len].
            use_mask_label = use_mask_label & query_mask
        else:
            use_mask_label = query_mask

        # Gather only the selected query positions (B_mask = number selected).
        span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
        position_mask = span_mask.squeeze(dim=-1)
        # [B_mask, context_len, span_embed_size]
        span_context_matrix = context_last_layer.masked_select(
            span_mask).view(-1, context_len, span_embed_size).contiguous()
        # [B_mask, span_embed_size]
        span_query_vector = query_last_layer.masked_select(
            position_mask).view(-1, span_embed_size).contiguous()
        # [B_mask, context_len]
        span_context_mask = context_expand_mask.masked_select(
            position_mask).view(-1, context_len).contiguous()

        # Per-context-position start/end scores, [B_mask, context_len].
        span_start_probs = self.start_attention(span_query_vector,
                                                span_context_matrix,
                                                span_context_mask)
        span_end_probs = self.end_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)

        # eps avoids log(0).
        span_start_logits = torch.log(span_start_probs + self._eps)
        span_end_logits = torch.log(span_end_probs + self._eps)

        # [B_mask, 2]: last dim is (start position, end position).
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # Score of each best span = start logit + end logit.
        best_span_scores = (
            torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
            torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
        # [B_mask, ]
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_spans": best_spans,
            "best_span_scores": best_span_scores
        }

        # Compute the loss when labels are provided.
        if start_label is not None:
            loss = self._calc_loss(span_start_logits, span_end_logits,
                                   use_mask_label, start_label, end_label,
                                   best_spans)
            output_dict["loss"] = loss
        if metadata is not None:
            predict_rewrite_results = self._get_rewrite_result(
                use_mask_label, best_spans, query_lens, context_lens, metadata)
            output_dict['rewrite_results'] = predict_rewrite_results
        return output_dict
Esempio n. 16
0
    def forward(self,
                tokens: TextFieldTensors,
                targets: TextFieldTensors,
                target_sentiments: torch.LongTensor = None,
                target_sequences: Optional[torch.LongTensor] = None,
                metadata: torch.LongTensor = None,
                position_weights: Optional[torch.LongTensor] = None,
                position_embeddings: Optional[Dict[str,
                                                   torch.LongTensor]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        '''
        Classify the sentiment of every target using attention over the context.

        The text and targets are dictionaries because, as text fields, they
        can be represented in many different ways, e.g. just words, or words
        and chars; the dictionary holds one entry per representation, e.g.
        {'words': words_tensor_ids, 'chars': char_tensor_ids}.

        :param tokens: The context text, shared by all targets of a sample.
        :param targets: Tokens of each target.
        :param target_sentiments: Optional gold labels aligned with the
            (batch size, number targets) label mask. When given, the loss is
            computed and ``self.metrics`` / ``self.f1_metrics`` are updated.
        :param target_sequences: Required when ``self._use_target_sequences``.
            Reshaped to (batch size * number targets, target length, context
            length) and matrix-multiplied with the context embeddings to
            obtain the target embeddings from the context.
        :param metadata: Despite the annotation, this is iterated as a
            sequence of per-sample dicts with the keys 'text words', 'text',
            'targets' and 'target words'.
        :param position_weights: Required when ``self.target_position_weight``
            is set.
        :param position_embeddings: Passed to ``concat_position_embeddings``.
        :returns: Dict with 'class_probabilities' and 'targets_mask', plus
            'loss' when ``target_sentiments`` is given and metadata-derived
            fields (including 'word_attention' and 'context_mask') when
            ``metadata`` is given.
        :raises ValueError: If an input required by the configuration is
            missing.
        '''
        # Fail fast, before any encoding work, on inputs the configured
        # components require.
        if self.target_position_weight is not None and position_weights is None:
            raise ValueError(
                'This model requires `position_weights` to '
                'better encode the target but none were given')
        if self._use_target_sequences and target_sequences is None:
            raise ValueError(
                'This model requires `target_sequences` to get the '
                'contextualized target through the context but none were given')

        # Get masks for the targets before they get manipulated
        targets_mask = util.get_text_field_mask(targets, num_wrapping_dims=1)
        # This is required if the input is of shape greater than 3 dim e.g.
        # character input where it is
        # (batch size, number targets, token length, char length)
        label_mask = (targets_mask.sum(dim=-1) >= 1).type(torch.int64)
        batch_size, number_targets = label_mask.shape
        batch_size_num_targets = batch_size * number_targets

        # Embed and encode text as a sequence
        embedded_context = self.context_field_embedder(tokens)
        embedded_context = self._variational_dropout(embedded_context)
        context_mask = util.get_text_field_mask(tokens)
        # Repeat the context once per target so it is of shape:
        # (Batch Size * Number Targets, Sequence Length, Dim) Currently:
        # (Batch Size, Sequence Length, Dim)
        batch_size, context_sequence_length, context_embed_dim = embedded_context.shape
        reshaped_embedding_context = embedded_context.unsqueeze(1).repeat(
            1, number_targets, 1, 1)
        reshaped_embedding_context = reshaped_embedding_context.view(
            batch_size_num_targets, context_sequence_length, context_embed_dim)
        # Embed and encode target as a sequence. If True here the target
        # embeddings come from the context.
        if self._use_target_sequences:
            _, _, target_sequence_length, target_index_length = target_sequences.shape
            target_index_len_err = (
                'The size of the context sequence '
                f'{context_sequence_length} is not the same'
                ' as the target index sequence '
                f'{target_index_length}. This is to get '
                'the contextualized target through the context')
            assert context_sequence_length == target_index_length, target_index_len_err
            seq_targets_mask = target_sequences.view(batch_size_num_targets,
                                                     target_sequence_length,
                                                     target_index_length)
            reshaped_embedding_targets = torch.matmul(
                seq_targets_mask.type(torch.float32),
                reshaped_embedding_context)
        else:
            temp_targets = elmo_input_reshape(targets, batch_size,
                                              number_targets,
                                              batch_size_num_targets)
            if self.target_field_embedder:
                embedded_targets = self.target_field_embedder(temp_targets)
            else:
                embedded_targets = self.context_field_embedder(temp_targets)
            # Undo elmo_input_reshape for whichever embedder was used; the
            # 4-dim unpack below expects (batch, targets, seq, dim).
            embedded_targets = elmo_input_reverse(embedded_targets,
                                                  targets, batch_size,
                                                  number_targets,
                                                  batch_size_num_targets)

            # Size (batch size, num targets, target sequence length, embedding dim)
            embedded_targets = self._time_variational_dropout(embedded_targets)
            batch_size, number_targets, target_sequence_length, target_embed_dim = embedded_targets.shape
            reshaped_embedding_targets = embedded_targets.view(
                batch_size_num_targets, target_sequence_length,
                target_embed_dim)

        encoded_targets_mask = targets_mask.view(batch_size_num_targets,
                                                 target_sequence_length)
        # Shape (Batch Size * Number targets), encoded dim
        encoded_targets_seq = self.target_encoder(reshaped_embedding_targets,
                                                  encoded_targets_mask)
        encoded_targets_seq = self._naive_dropout(encoded_targets_seq)

        repeated_context_mask = context_mask.unsqueeze(1).repeat(
            1, number_targets, 1)
        repeated_context_mask = repeated_context_mask.view(
            batch_size_num_targets, context_sequence_length)
        # Encoded target repeated for every context word.
        repeated_encoded_targets = encoded_targets_seq.unsqueeze(1).repeat(
            1, context_sequence_length, 1)
        if self._AE:
            reshaped_embedding_context = torch.cat(
                (reshaped_embedding_context, repeated_encoded_targets), -1)
        # add position embeddings if required.
        reshaped_embedding_context = concat_position_embeddings(
            reshaped_embedding_context, position_embeddings,
            self.target_position_embedding)
        # Size (batch size * number targets, sequence length, embedding dim)
        reshaped_encoded_context_seq = self.context_encoder(
            reshaped_embedding_context, repeated_context_mask)
        reshaped_encoded_context_seq = self._variational_dropout(
            reshaped_encoded_context_seq)
        # Weighted position information encoded into the context sequence.
        # (position_weights was validated at the top of this method.)
        if self.target_position_weight is not None:
            position_output = self.target_position_weight(
                reshaped_encoded_context_seq, position_weights,
                repeated_context_mask)
            reshaped_encoded_context_seq, _ = position_output
        # Whether to concat the aspect embeddings on to the contextualised word
        # representations
        attention_encoded_context_seq = reshaped_encoded_context_seq
        if self._AttentionAE:
            attention_encoded_context_seq = torch.cat(
                (attention_encoded_context_seq, repeated_encoded_targets), -1)

        # Projection layer before the attention layer
        attention_encoded_context_seq = self.attention_project_layer(
            attention_encoded_context_seq)
        attention_encoded_context_seq = self._context_attention_activation_function(
            attention_encoded_context_seq)
        attention_encoded_context_seq = self._variational_dropout(
            attention_encoded_context_seq)

        # Attention over the context sequence
        attention_vector = self.attention_vector.unsqueeze(0).repeat(
            batch_size_num_targets, 1)
        attention_weights = self.context_attention_layer(
            attention_vector, attention_encoded_context_seq,
            repeated_context_mask)
        expanded_attention_weights = attention_weights.unsqueeze(-1)
        weighted_encoded_context_seq = reshaped_encoded_context_seq * expanded_attention_weights
        weighted_encoded_context_vec = weighted_encoded_context_seq.sum(dim=1)

        # Add the last hidden state of the context vector, with the attention vector
        context_final_states = util.get_final_encoder_states(
            reshaped_encoded_context_seq,
            repeated_context_mask,
            bidirectional=self.context_encoder_bidirectional)
        context_final_states = self.final_hidden_state_projection_layer(
            context_final_states)
        weighted_encoded_context_vec = self.final_attention_projection_layer(
            weighted_encoded_context_vec)
        feature_vector = context_final_states + weighted_encoded_context_vec
        feature_vector = self._naive_dropout(feature_vector)
        # Reshape the vector into (Batch Size, Number Targets, feature dim)
        _, feature_dim = feature_vector.shape
        feature_target_seq = feature_vector.view(batch_size, number_targets,
                                                 feature_dim)

        if self.inter_target_encoding is not None:
            feature_target_seq = self.inter_target_encoding(
                feature_target_seq, label_mask)
            feature_target_seq = self._variational_dropout(feature_target_seq)

        if self.feedforward is not None:
            feature_target_seq = self.feedforward(feature_target_seq)

        logits = self.label_projection(feature_target_seq)
        masked_class_probabilities = util.masked_softmax(
            logits, label_mask.unsqueeze(-1))
        output_dict = {
            "class_probabilities": masked_class_probabilities,
            "targets_mask": label_mask
        }
        # Convert it to bool tensor.
        label_mask = label_mask == 1

        if target_sentiments is not None:
            # gets the loss per target instance due to the average=`token`
            if self.loss_weights is not None:
                loss = util.sequence_cross_entropy_with_logits(
                    logits,
                    target_sentiments,
                    label_mask,
                    average='token',
                    alpha=self.loss_weights)
            else:
                loss = util.sequence_cross_entropy_with_logits(
                    logits, target_sentiments, label_mask, average='token')
            for metrics in [self.metrics, self.f1_metrics]:
                for metric in metrics.values():
                    metric(logits, target_sentiments, label_mask)
            output_dict["loss"] = loss

        if metadata is not None:
            # Distinct local names so the ``targets`` parameter is not shadowed.
            words = []
            texts = []
            sample_targets = []
            target_words = []
            for sample in metadata:
                words.append(sample['text words'])
                texts.append(sample['text'])
                sample_targets.append(sample['targets'])
                target_words.append(sample['target words'])

            output_dict["words"] = words
            output_dict["text"] = texts
            word_attention_weights = attention_weights.view(
                batch_size, number_targets, context_sequence_length)
            output_dict["word_attention"] = word_attention_weights
            output_dict["targets"] = sample_targets
            output_dict["target words"] = target_words
            output_dict["context_mask"] = context_mask

        return output_dict