Example #1
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    The given ``indices`` of size ``(set_size, subset_size)`` specify, for each of the
    ``set_size`` rows, a subset of the ``target`` to select. The ``target`` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has size
    ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.view(-1))

    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
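A minimal usage sketch for flattened_index_select (shapes are illustrative; assumes ``torch`` is imported and the function is in scope):

import torch

target = torch.randn(2, 5, 8)                     # (batch_size=2, sequence_length=5, embedding_size=8)
indices = torch.tensor([[0, 1], [2, 3], [4, 0]])  # (set_size=3, subset_size=2)
selected = flattened_index_select(target, indices)
print(selected.shape)  # torch.Size([2, 3, 2, 8])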
Example #2
 def extract_features(self,
                      tokens: torch.LongTensor,
                      return_all_hiddens: bool = False) -> torch.Tensor:
     if tokens.dim() == 1:
         tokens = tokens.unsqueeze(0)
     if tokens.size(-1) > self.model.max_positions():
         raise ValueError('tokens exceeds maximum length: {} > {}'.format(
             tokens.size(-1), self.model.max_positions()))
     features, extra = self.model(
         tokens.to(device=self.device),
         features_only=True,
         return_all_hiddens=return_all_hiddens,
     )
     if return_all_hiddens:
         # convert from T x B x C -> B x T x C
         inner_states = extra['inner_states']
         return [
             inner_state.transpose(0, 1) for inner_state in inner_states
         ]
     else:
         return features  # just the last layer's features
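A hedged usage sketch, assuming ``roberta`` is a fairseq RoBERTa hub interface (e.g. loaded via ``torch.hub.load('pytorch/fairseq', 'roberta.base')``) that exposes this method:

tokens = roberta.encode("Hello world!")  # 1-D LongTensor of BPE ids
last_layer = roberta.extract_features(tokens)  # (1, T, C): last layer only
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)  # list of (1, T, C) tensors
print(last_layer.shape, len(all_layers))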
Example #3
    def decode(
        self,
        tokens: torch.LongTensor,
        skip_special_tokens: bool = True,
        remove_bpe: bool = True,
    ) -> str:
        assert tokens.dim() == 1
        tokens = tokens.numpy()

        if tokens[0] == self.task.source_dictionary.bos() and skip_special_tokens:
            tokens = tokens[1:]  # remove <s>

        eos_mask = tokens == self.task.source_dictionary.eos()
        doc_mask = eos_mask[1:] & eos_mask[:-1]
        sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)

        if skip_special_tokens:
            sentences = [
                np.array(
                    [c
                     for c in s
                     if c != self.task.source_dictionary.eos()])
                for s in sentences
            ]

        sentences = [
            " ".join([self.task.source_dictionary.symbols[c]
                      for c in s])
            for s in sentences
        ]

        if remove_bpe:
            sentences = [
                s.replace(" ", "").replace("▁", " ").strip() for s in sentences
            ]
        if len(sentences) == 1:
            return sentences[0]
        return sentences
Example #4
 def decode(self, tokens: torch.LongTensor, dict):
     assert tokens.dim() == 1
     tokens = tokens.cpu().numpy()
     if tokens[0] == self.src_dict.bos():
         tokens = tokens[1:]  # remove <s>
     eos_mask = (tokens == self.src_dict.eos())
     doc_mask = eos_mask[1:] & eos_mask[:-1]
     sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
     new_sentences = []
     for s in sentences:
         _s = dict.string(s,
                          extra_symbols_to_ignore=[
                              0, 1, 2, 3, 50262, 50263, 50264, 50265
                          ])
         #print(s, _s)
         _s = self.bpe.decode(_s)
         new_sentences.append(_s)
     sentences = new_sentences
     #sentences = [self.bpe.decode(self.src_dict.string(s)) for s in sentences]
     if len(sentences) == 1:
         return sentences[0]
     return sentences
Example #5
    def _compute_joint_llh(self, emissions: torch.Tensor,
                           tags: torch.LongTensor,
                           mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.size()[:2] == tags.size()
        assert emissions.size(2) == self.num_tags
        assert mask.size() == tags.size()
        assert all(mask[0])

        seq_length = emissions.size(0)
        mask = mask.float()

        # Start transition score
        llh = self.start_transitions[tags[0]]  # (batch_size,)

        for i in range(seq_length - 1):
            cur_tag, next_tag = tags[i], tags[i + 1]
            # Emission score for current tag
            llh += emissions[i].gather(1, cur_tag.view(-1, 1)).squeeze(1) * mask[i]
            # Transition score to next tag
            transition_score = self.transitions[cur_tag, next_tag]
            # Only add transition score if the next tag is not masked (mask == 1)
            llh += transition_score * mask[i + 1]

        # Find last tag index
        last_tag_indices = mask.long().sum(0) - 1  # (batch_size,)
        last_tags = tags.gather(0, last_tag_indices.view(1, -1)).squeeze(0)

        # End transition score
        llh += self.end_transitions[last_tags]
        # Emission score for the last tag, if mask is valid (mask == 1)
        llh += emissions[-1].gather(1, last_tags.view(-1, 1)).squeeze(1) * mask[-1]

        return llh
Example #6
    def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)

        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape

        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score
Example #7
    def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        seq_ends = mask.long().sum(dim=0) - 1
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        score += self.end_transitions[last_tags]

        return score
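For orientation, the gold-path score computed above is only the numerator of the CRF log-likelihood; a minimal, hedged sketch of how it is typically combined with the log partition function (``_compute_normalizer`` is an assumed helper with the same (seq_length, batch_size) layout, as in pytorch-crf):

    def log_likelihood(self, emissions: torch.Tensor, tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        # Hedged sketch: numerator scores the gold tag sequence; the assumed
        # _compute_normalizer helper returns the log partition over all sequences.
        numerator = self._compute_score(emissions, tags, mask)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,) log-likelihood per sequence
        return numerator - denominator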
Example #8
    def _computer_score(self, emissions: torch.Tensor, tags: torch.LongTensor,
                        mask: torch.ByteTensor) -> torch.Tensor:

        # batch second
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # self.start_transitions: transition scores from the start state to each tag (excluding end)
        score = self.start_transitions[tags[0]]

        # emissions.shape (seq_len,batch_size,tag_nums)

        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):

            # if mask[i].sum() == 0:
            #     break

            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # Get the tag of the last word in each sample.
        # shape: (batch_size,)  true (unpadded) length of each sequence in the batch
        seq_ends = mask.long().sum(dim=0) - 1
        # tag of the last word in each sample
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)  add the end-transition score for each sample's last tag to the accumulated score
        score += self.end_transitions[last_tags]
        return score
Example #9
    def log_probs(self,
                  heads: LongTensor,
                  types: LongTensor,
                  score_only: bool = False) -> Tensor:
        """Compute the log probability of a labeled dependency tree.

        Args:
            heads: Tensor of shape (B, N) containing the index/position of the head of
                each word.
            types: Tensor of shape (B, N) containing the dependency types for the
                corresponding head-dependent relation.
            score_only: Whether to compute only the score of the tree. Useful for training
                with cross-entropy loss.

        Returns:
            1-D tensor of length B containing the log probabilities.
        """
        assert heads.dim() == 2
        assert types.shape == heads.shape
        assert self.mask is not None

        scores = self.scores
        bsz, slen, _, n_types = self.scores.shape

        # broadcast over types
        heads = heads.unsqueeze(2).expand(bsz, slen, n_types)  # type: ignore
        # shape: (bsz, slen, n_types)
        scores = scores.gather(1, heads.unsqueeze(1)).squeeze(1)
        # shape: (bsz, slen)
        scores = scores.gather(2, types.unsqueeze(2)).squeeze(2)
        # mask scores from invalid dependents
        scores = scores.masked_fill(~self.mask, 0)
        # mask scores of root as dependents
        scores = scores.masked_fill(
            torch.arange(slen).to(scores.device) == self.ROOT, 0)

        return scores.sum(dim=1) - (0 if score_only else self.log_partitions())
Example #10
def segment_lengths_to_ids(
        segment_lengths: torch.LongTensor) -> torch.LongTensor:
    """
    Args:
        segment_lengths: Non-negative lengths of the tensor segments

    Returns:
        A tensor containing ids for every element in the tensor to be segmented

    Examples:
        >>> segments = torch.tensor([2, 4, 3, 1])
        >>> segment_lengths_to_ids(segments)
        tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3])
    """
    if segment_lengths.dim() != 1:
        raise ValueError(
            f'`segment_lengths` should have a single dimension, got shape {segment_lengths.shape}'
        )
    if (segment_lengths < 0).any():
        raise ValueError(
            'All entries in `segment_lengths` should be non-negative')

    return segment_lengths.new_tensor(
        np.arange(len(segment_lengths)).repeat(segment_lengths.cpu().numpy()))
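On recent PyTorch versions the same ids can be produced without the NumPy round-trip; a hedged alternative (not part of the original function):

ids = torch.repeat_interleave(
    torch.arange(len(segment_lengths), device=segment_lengths.device),
    segment_lengths)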
Example #11
def count_correct(
    heads: LongTensor,
    types: LongTensor,
    pred_heads: LongTensor,
    pred_types: LongTensor,
    mask: BoolTensor,
    nopunct_mask: BoolTensor,
    proj_mask: BoolTensor,
    root_idx: int = 0,
    type_idx: Optional[int] = None,
) -> Union["Counts", "TypeWiseCounts"]:
    # shape: (bsz, slen)
    assert heads.dim() == 2
    assert types.shape == heads.shape
    assert pred_heads.shape == heads.shape
    assert pred_types.shape == heads.shape
    assert mask.shape == heads.shape
    assert nopunct_mask.shape == heads.shape
    assert proj_mask.shape == heads.shape

    corr_heads = heads == pred_heads
    corr_types = types == pred_types

    if type_idx is None:
        root_mask = heads == root_idx
        nonproj_mask = ~torch.all(proj_mask | (~mask), dim=1, keepdim=True)

        usents = int(torch.all(corr_heads | (~mask), dim=1).long().sum())
        usents_nopunct = int(
            torch.all(corr_heads | (~mask) | (~nopunct_mask),
                      dim=1).long().sum())
        lsents = int(
            torch.all(corr_heads & corr_types | (~mask), dim=1).long().sum())
        lsents_nopunct = int(
            torch.all(corr_heads & corr_types | (~mask) | (~nopunct_mask),
                      dim=1).long().sum())
        uarcs = int((corr_heads & mask).long().sum())
        uarcs_nopunct = int((corr_heads & mask & nopunct_mask).long().sum())
        uarcs_nonproj = int((corr_heads & mask & nonproj_mask).long().sum())
        larcs = int((corr_heads & corr_types & mask).long().sum())
        larcs_nopunct = int(
            (corr_heads & corr_types & mask & nopunct_mask).long().sum())
        larcs_nonproj = int(
            (corr_heads & corr_types & mask & nonproj_mask).long().sum())
        roots = int((corr_heads & mask & root_mask).long().sum())
        n_sents = heads.size(0)
        n_arcs = int(mask.long().sum())
        n_arcs_nopunct = int((mask & nopunct_mask).long().sum())
        n_arcs_nonproj = int((mask & nonproj_mask).long().sum())
        n_roots = int((mask & root_mask).long().sum())

        return Counts(
            usents,
            usents_nopunct,
            lsents,
            lsents_nopunct,
            uarcs,
            uarcs_nopunct,
            uarcs_nonproj,
            larcs,
            larcs_nopunct,
            larcs_nonproj,
            roots,
            n_sents,
            n_arcs,
            n_arcs_nopunct,
            n_arcs_nonproj,
            n_roots,
        )

    assert type_idx is not None
    type_mask = types == type_idx

    uarcs = int((corr_heads & type_mask & mask).long().sum())
    uarcs_nopunct = int(
        (corr_heads & type_mask & nopunct_mask & mask).long().sum())
    larcs = int((corr_heads & corr_types & type_mask & mask).long().sum())
    larcs_nopunct = int((corr_heads & corr_types & type_mask & nopunct_mask
                         & mask).long().sum())
    n_arcs = int((type_mask & mask).long().sum())
    n_arcs_nopunct = int((type_mask & nopunct_mask & mask).long().sum())

    return TypeWiseCounts(type_idx, uarcs, uarcs_nopunct, larcs, larcs_nopunct,
                          n_arcs, n_arcs_nopunct)
Example #12
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            verb_indicator: torch.LongTensor,
            tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            A one-hot/all-zeros ``IndexField`` representation of the position of the verb in the sentence.
            This should have shape (batch_size, num_tokens) and importantly, can be all zeros, in the case
            that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of gold labels.  These can either be integer
            indexes or one hot arrays of labels, so of shape ``(batch_size, num_tokens)`` or of
            shape ``(batch_size, num_tokens, num_tags)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)
        expanded_verb_indicator = verb_indicator.unsqueeze(-1).float()
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + 1).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, expanded_verb_indicator], -1)
        batch_size, sequence_length, embedding_dim_with_binary_feature = \
            embedded_text_with_verb_indicator.size()

        if self.stacked_encoder.get_input_dim() != embedding_dim_with_binary_feature:
            raise ConfigurationError(
                "The SRL model uses an indicator feature, which makes "
                "the embedding dimension one larger than the value "
                "specified. Therefore, the 'input_dim' of the stacked_encoder "
                "must be equal to total_embedding_dim + 1.")

        batch_sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        encoded_text = self.stacked_encoder(embedded_text_with_verb_indicator,
                                            batch_sequence_lengths)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        if tags is not None:
            # Negative log likelihood criterion takes integer labels, not one hot.
            if tags.dim() == 3:
                _, tags = tags.max(-1)
            loss = sequence_cross_entropy_with_logits(logits, tags, mask)
            self.span_metric(class_probabilities, tags, mask)
            output_dict["loss"] = loss

        return output_dict
Example #13
def eval_probability(
        model: transformer.Transformer,
        input_seq: torch.LongTensor,
        target_seq: torch.LongTensor,
        pad_index: int = None
) -> torch.FloatTensor:
    """Computes the probability that the provided model computes a target sequence given an input sequence.
    
    Args:
         model (:class:`transformer.Transformer`): The model to use.
         input_seq (torch.LongTensor): The input sequence to be provided to the model. This has to be a
            (batch-size x input-seq-len)-tensor.
         target_seq (torch.LongTensor): The target sequence whose probability is being evaluated. This has to be a
            (batch-size x target-seq-len)-tensor.
         pad_index (int, optional): The index that indicates a padding token in a sequence. If ``target_seq`` is padded,
            then the ``pad_index`` has to be provided in order to allow for computing the probabilities for relevant
            parts of the target sequence only.
    
    Returns:
        torch.FloatTensor: A 1D-tensor of size (batch-size), which contains one probability for each sample in
            ``input_seq`` and ``target_seq``, respectively.
    """
    if not isinstance(model, transformer.Transformer):
        raise TypeError("The <model> has to be a transformer.Transformer!")
    if not isinstance(input_seq, torch.LongTensor) and not isinstance(input_seq, torch.cuda.LongTensor):
        raise TypeError("The <input_seq> has to be a LongTensor!")
    if input_seq.dim() != 2:
        raise ValueError("<input_seq> has to be a 2D-tensor!")
    if input_seq.is_cuda:
        if not isinstance(target_seq, torch.cuda.LongTensor):
            raise TypeError("The <target_seq> has to be of the same type as <input_seq>, i.e., cuda.LongTensor!")
    elif not isinstance(target_seq, torch.LongTensor):
        raise TypeError("The <target_seq> has to be of the same type as <input_seq>, i.e., LongTensor!")
    if target_seq.dim() != 2:
        raise ValueError("<input_seq> has to be a 2D-tensor!")
    if input_seq.size(0) != target_seq.size(0):
        raise ValueError("<input_seq> and <target_seq> use different batch sizes!")
    if pad_index is not None and not isinstance(pad_index, int):
        raise TypeError("The <pad_index>, if provided, has to be an integer!")
    
    batch_size = input_seq.size(0)
    max_seq_len = input_seq.size(1)
    
    # put model in evaluation mode
    original_mode = model.training  # store original mode (train/eval) to be restored eventually
    model.eval()
    
    # run the model to compute the needed probabilities
    predictions = model(input_seq, target_seq)
    
    # determine the lengths of the target sequences
    if pad_index is not None:
        mask = util.create_padding_mask(target_seq, pad_index)[:, 0, :]
        seq_len = mask.sum(dim=1).cpu().numpy().tolist()
    else:
        seq_len = (np.ones(batch_size, dtype=np.int64) * target_seq.size(1)).tolist()
    
    # compute the probabilities for each of the provided samples
    sample_probs = torch.ones(batch_size)
    for sample_idx in range(batch_size):  # iterate over each sample
        for token_idx in range(seq_len[sample_idx]):  # iterate over each position in the output sequence
            sample_probs[sample_idx] *= predictions[sample_idx, token_idx, target_seq[sample_idx, token_idx]].item()

    # restore original mode of the model
    model.train(mode=original_mode)
    
    return sample_probs
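A hedged usage sketch for eval_probability, assuming ``model`` is a trained ``transformer.Transformer`` and that index 0 is the padding token in the target vocabulary (token ids are illustrative):

src = torch.LongTensor([[5, 9, 2, 4], [7, 3, 4, 8]])  # (batch-size=2, input-seq-len=4)
tgt = torch.LongTensor([[6, 2, 0, 0], [9, 9, 1, 2]])  # (batch-size=2, target-seq-len=4)
probs = eval_probability(model, src, tgt, pad_index=0)
print(probs)  # 1-D tensor with one probability per sample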
Example #14
    def forward(
            self,
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)

        span_representations = self.span_extractor(encoded_text, spans, mask,
                                                   span_mask)

        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)

        logits = self.tag_projection_layer(span_representations)
        class_probabilities = masked_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
            "class_probabilities": class_probabilities,
            "spans": spans,
            "tokens": [meta["tokens"] for meta in metadata],
            "pos_tags": [meta.get("pos_tags") for meta in metadata],
            "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits.float(),
                                                      span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees
               ) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [
                list(zip(*tree.pos()))[1] for tree in batch_gold_trees
            ]
            predicted_trees = self.construct_trees(
                class_probabilities.cpu().data,
                spans.cpu().data, num_spans.data, output_dict["tokens"],
                gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
Example #15
    def forward(self,
                indices: torch.LongTensor,
                offsets: Optional[torch.LongTensor] = None,
                per_index_weights: Optional[torch.Tensor] = None):
        """
        Forward process to the embedding bag layer.
        :param indices: Tensor containing bags of indices into the embedding matrix.
        :param offsets: Only used when indices is 1D. offsets determines the starting index position of each bag
        (sequence)in input.
        :param per_index_weights: a tensor of float / double weights, or None to indicate that all weights should be
        taken to be 1. If specified, per_index_weights must have exactly the same shape as indices and is treated as
        having the same offsets, if those are not None.
        :return: an #bag x embedding_dim Tensor.
        """

        # always move indices to cpu, as we need to get its corresponding minhash values from table in memory
        indices = indices.cpu()

        # Validate the inputs.
        if per_index_weights is not None and indices.size() != per_index_weights.size():
            raise ValueError("embedding_bag: If per_index_weights ({}) is not None, "
                             "then it must have the same shape as the indices ({})"
                             .format(per_index_weights.shape, indices.shape))
        if indices.dim() == 2:
            if offsets is not None:
                raise ValueError("if input is 2D, then offsets has to be None"
                                 ", as input is treated is a mini-batch of"
                                 " fixed length sequences. However, found "
                                 "offsets of type {}".format(type(offsets)))
            offsets = torch.arange(0, indices.numel(), indices.size(1), dtype=torch.long, device=indices.device)
            indices = indices.reshape(-1)
            if per_index_weights is not None:
                per_index_weights = per_index_weights.reshape(-1)
        elif indices.dim() == 1:
            if offsets is None:
                raise ValueError("offsets has to be a 1D Tensor but got None")
            if offsets.dim() != 1:
                raise ValueError("offsets has to be a 1D Tensor")
        else:
            ValueError("input has to be 1D or 2D Tensor,"
                       " but got Tensor of dimension {}".format(input.dim()))

        num_bags = offsets.size(0)

        # get the min-hash for each category value, note that lsh_weight_index is in cpu memory
        lsh_weight_index = self._minhash_table[indices]
        # print("In forward: ", lsh_weight_index, indices, self._minhash_table[indices], self.lsh_weight_size)

        # move the min-hash values to target device
        lsh_weight_index = lsh_weight_index.to(self.hashed_weight.device)
        lsh_weight_index %= self.lsh_weight_size

        # indices_embedding_vector is a |indices| x |embedding_dim| tensor.
        indices_embedding_vectors = self.hashed_weight[lsh_weight_index]
        # print('indices_embedding_vectors: ', lsh_weight_index, indices_embedding_vectors)

        # multiply embedding vectors by weights
        if per_index_weights is not None:
            per_index_weights = per_index_weights.to(indices_embedding_vectors.device)
            indices_embedding_vectors *= per_index_weights[:, None]
        # print("per_index_weights",per_index_weights)
        offsets2bag = make_offset2bag(offsets, indices)
        # print("offsets2bag: ", offsets2bag)
        if self._mode == "sum" or self._mode == "mean":
            result = \
                torch.zeros(num_bags, self.embedding_dim, dtype=indices_embedding_vectors.dtype,
                            device=self.hashed_weight.device)
            result.index_add_(0, offsets2bag, indices_embedding_vectors)
            if self._mode == "sum":
                return result

            # self._mode == "mean":
            bag_size = make_bag_size(offsets, indices).to(result.device)
            result /= bag_size[:, None]
            return result
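A hedged usage sketch, assuming ``bag`` is an instance of the LSH embedding-bag module that owns this ``forward`` (constructor arguments are hypothetical and omitted), configured with mode "mean" or "sum":

indices = torch.LongTensor([[1, 4, 7], [2, 2, 9]])  # (num_bags=2, bag_length=3)
pooled = bag(indices)  # shape (2, embedding_dim); offsets are inferred for 2-D indices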
Example #16
    def forward(
            self,  # type: ignore
            question_passage: Dict[str, torch.LongTensor],
            sentences_mask: torch.LongTensor,
            sentences_tokens: torch.LongTensor,
            question_sentences_tokens: torch.LongTensor,
            number_indices: torch.LongTensor,
            mask_indices: torch.LongTensor,
            num_spans: torch.LongTensor = None,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_expressions: torch.LongTensor = None,
            answer_as_expressions_extra: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        # Shape: (batch_size, seqlen) - all questions and passage tokens
        question_passage_tokens = question_passage["tokens"]
        # Shape: (batch_size, seqlen) - 1 for all question and passage tokens, 0 for padding
        pad_mask = question_passage["mask"]
        # Shape: (batch_size, seqlen) - 0 for all question tokens, 1 for passage tokens, 0 for padding
        seqlen_ids = question_passage["tokens-type-ids"]

        max_seqlen = question_passage_tokens.shape[-1]
        batch_size = question_passage_tokens.shape[0]

        # Shape: (batch_size, 3) - (0, question length - 1, question + passage length - 1)
        mask = mask_indices.squeeze(-1)
        # Shape: (batch_size, seqlen) - 0 for all [CLS] (question start) and [SEP] (passage start) tokens, 1 for rest
        cls_sep_mask = \
            torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(1, mask, torch.zeros(mask.shape, device=mask.device).long())
        # Shape: (batch_size, seqlen) - 1 for all passage tokens excluding [CLS], [SEP], 0 for others
        passage_mask = seqlen_ids * pad_mask * cls_sep_mask
        # Shape: (batch_size, seqlen) -1 for all question tokens excluding [CLS], [SEP], 0 for others
        question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask

        # Contains a single summary vec for each sentence (weighted sum of bert embeddings)
        # Shape: (batch_size, sentence_count, bert_dim)
        with torch.no_grad():
            sentences_embeddings, sentences_embedding_mask = self.extract_sentence_embeddings(
                question_sentences_tokens)
        sentences_summary_vecs = self.summary_vector(sentences_embeddings,
                                                     sentences_embedding_mask,
                                                     "sentence")

        # Shape: (batch_size, seqlen, bert_dim) ; self.BERT is a BertModel
        # seqlen_ids: mask of Sentence A - 0, Sentence B - 1
        # pad_mask: 1 for real tokens, 0 for padding
        bert_out, _ = self.BERT(question_passage_tokens,
                                seqlen_ids,
                                pad_mask,
                                output_all_encoded_layers=False)
        # Shape: (batch_size, qlen, bert_dim)
        question_end = max(mask[:, 1])  # index of the last question token (max over the batch)
        # Contains question token embeddings + some unfiltered embeddings from the passage
        question_out = bert_out[:, :question_end]
        # Shape: (batch_size, qlen)
        question_mask = question_mask[:, :question_end]  # crop mask to match question_out
        # Shape: (batch_size, out)
        question_vector = self.summary_vector(question_out, question_mask,
                                              "question")

        passage_out = bert_out
        del bert_out

        # Shape: (batch_size, bert_dim)
        passage_vector = self.summary_vector(passage_out, passage_mask)

        if "arithmetic" in self.answering_abilities and self.arithmetic == "advanced":
            arithmetic_summary = self.summary_vector(passage_out, pad_mask,
                                                     "arithmetic")
            #             arithmetic_summary = self.summary_vector(question_out, question_mask, "arithmetic")

            # Shape: (batch_size, # of numbers in the passage)
            if number_indices.dim() == 3:
                number_indices = number_indices[:, :, 0].long()
            number_mask = (number_indices != -1).long()
            clamped_number_indices = util.replace_masked_values(
                number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                passage_out, 1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, passage_out.size(-1)))
            op_mask = torch.ones((batch_size, self.num_ops + 1),
                                 device=number_mask.device).long()
            options_mask = torch.cat([op_mask, number_mask], -1)
            ops = self.op_embeddings(
                torch.arange(self.num_ops + 1,
                             device=number_mask.device).expand(batch_size, -1))
            options = torch.cat([self.Wo(ops), self.Wc(encoded_numbers)], 1)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            count_number, best_count_number = self._count_module(
                passage_vector)
            # count_number, select_probs = self._count_module_per_sentence(passage_vector, sentences_summary_vecs)
            # answer_as_counts = answer_as_counts.squeeze(1)
            # gold_count_mask = (answer_as_counts != -1).long()
            # count_number = util.replace_masked_values(count_number, gold_count_mask, 0)

        if "passage_span_extraction" in self.answering_abilities:
            passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \
                self._passage_span_module(passage_out, passage_mask)

        if "question_span_extraction" in self.answering_abilities:
            question_span_start_log_probs, question_span_end_log_probs, best_question_span = \
                self._question_span_module(passage_vector, question_out, question_mask)

        if "arithmetic" in self.answering_abilities:
            if self.arithmetic == "base":
                number_mask = (number_indices[:, :, 0].long() != -1).long()
                number_sign_log_probs, best_signs_for_numbers, number_mask = \
                    self._base_arithmetic_module(passage_vector, passage_out, number_indices, number_mask)
            else:
                arithmetic_logits, best_expression = \
                    self._adv_arithmetic_module(arithmetic_summary, self.max_explen, options, options_mask, \
                                                   passage_out, pad_mask)
                shapes = arithmetic_logits.shape
                # Replace NaN logits with random values so the loss does not become NaN.
                if torch.isnan(arithmetic_logits).any():
                    print("bad logits")
                    arithmetic_logits = torch.rand(
                        shapes,
                        device=arithmetic_logits.device,
                        requires_grad=True)

        output_dict = {}
        del passage_out, question_out
        # If answer is given, compute the loss.
        # if answer_as_passage_spans is not None or answer_as_question_spans is not None \
        #         or answer_as_expressions is not None or answer_as_counts is not None:
        #
        #     log_marginal_likelihood_list = []
        #     regression_loss = []
        #
        #     for answering_ability in self.answering_abilities:
        #         if answering_ability == "passage_span_extraction":
        #             log_marginal_likelihood_for_passage_span = \
        #                 self._passage_span_log_likelihood(answer_as_passage_spans,
        #                                                   passage_span_start_log_probs,
        #                                                   passage_span_end_log_probs)
        #             log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)
        #
        #         elif answering_ability == "question_span_extraction":
        #             log_marginal_likelihood_for_question_span = \
        #                 self._question_span_log_likelihood(answer_as_question_spans,
        #                                                    question_span_start_log_probs,
        #                                                    question_span_end_log_probs)
        #             log_marginal_likelihood_list.append(log_marginal_likelihood_for_question_span)
        #
        #         elif answering_ability == "arithmetic":
        #             if self.arithmetic == "base":
        #                 log_marginal_likelihood_for_arithmetic = \
        #                     self._base_arithmetic_log_likelihood(answer_as_expressions,
        #                                                          number_sign_log_probs,
        #                                                          number_mask,
        #                                                          answer_as_expressions_extra)
        #             else:
        #                 max_explen = answer_as_expressions.shape[-1]
        #                 possible_exps = answer_as_expressions.shape[1]
        #                 limit = min(possible_exps, 1000)
        #                 log_marginal_likelihood_for_arithmetic = \
        #                     self._adv_arithmetic_log_likelihood(arithmetic_logits[:,:max_explen,:],
        #                                                         answer_as_expressions[:,:limit,:].long())
        #             log_marginal_likelihood_list.append(log_marginal_likelihood_for_arithmetic)
        #
        #         elif answering_ability == "counting":
        #             regression_loss_for_count, best_count_number = \
        #                 self._count_regression(answer_as_counts, count_number_regression)
        #             regression_loss.append(regression_loss_for_count)
        #             # TODO: This is wrong
        #             log_marginal_likelihood_list.append(torch.tensor([-1e-7] * batch_size, device=regression_loss_for_count.device))
        #
        #         else:
        #             raise ValueError(f"Unsupported answering ability: {answering_ability}")
        #
        #     if len(self.answering_abilities) > 1:
        #         # Add the ability probabilities if there are more than one abilities
        #         all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
        #         all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
        #         marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
        #
        #         if len(regression_loss) > 0:
        #             marginal_log_likelihood -= regression_loss[0]
        #     else:
        #         if len(regression_loss) > 0:
        #             marginal_log_likelihood = -regression_loss[0]
        #         else:
        #             marginal_log_likelihood = log_marginal_likelihood_list[0]   # TODO: Handle

        output_dict["loss"] = -self._count_module(passage_vector)[0].mean()
        # output_dict["loss"] = loss_utils.count_loss(answer_as_counts, count_number, select_probs, passage_mask)

        with torch.no_grad():
            # Compute the metrics and add the tokenized input to the output.
            if metadata is not None:
                output_dict["question_id"] = []
                output_dict["answer"] = []
                question_tokens = []
                passage_tokens = []
                for i in range(batch_size):
                    if len(self.answering_abilities) > 1:
                        predicted_ability_str = self.answering_abilities[
                            best_answer_ability[i]]
                    else:
                        predicted_ability_str = self.answering_abilities[0]
                    answer_json: Dict[str, Any] = {}

                    # We did not consider multi-mention answers here
                    if predicted_ability_str == "passage_span_extraction":
                        answer_json["answer_type"] = "passage_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(question_passage_tokens[i], best_passage_span[i])
                    elif predicted_ability_str == "question_span_extraction":
                        answer_json["answer_type"] = "question_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(question_passage_tokens[i], best_question_span[i])
                    elif predicted_ability_str == "arithmetic":  # plus_minus combination answer
                        answer_json["answer_type"] = "arithmetic"
                        original_numbers = metadata[i]['original_numbers']
                        if self.arithmetic == "base":
                            answer_json["value"], answer_json["numbers"] = \
                                self._base_arithmetic_prediction(original_numbers, number_indices[i], best_signs_for_numbers[i])
                        else:
                            answer_json["value"], answer_json["expression"] = \
                                self._adv_arithmetic_prediction(original_numbers, best_expression[i])
                    elif predicted_ability_str == "counting":
                        answer_json["answer_type"] = "count"
                        answer_json["value"], answer_json["count"] = \
                            self._count_prediction(count_number[i])
                    else:
                        raise ValueError(
                            f"Unsupported answer ability: {predicted_ability_str}"
                        )

                    output_dict["question_id"].append(
                        metadata[i]["question_id"])
                    output_dict["answer"].append(answer_json)
                    answer_annotations = metadata[i].get(
                        'answer_annotations', [])
                    if answer_annotations:
                        self._drop_metrics(answer_json["value"],
                                           answer_annotations)

        return output_dict
Example #17
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of gold labels.  These can either be integer
            indexes or one hot arrays of labels, so of shape ``(batch_size, num_tokens)`` or of
            shape ``(batch_size, num_tokens, num_tags)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = embedded_text_input.size()
        mask = get_text_field_mask(tokens)
        batch_sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        encoded_text = self.stacked_encoder(embedded_text_input,
                                            batch_sequence_lengths)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])

        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }

        if tags is not None:
            # Negative log likelihood criterion takes integer labels, not one hot.
            if tags.dim() == 3:
                _, tags = tags.max(-1)
            loss = sequence_cross_entropy_with_logits(logits, tags, mask)
            for metric in self.metrics.values():
                metric(logits, tags, mask.float())
            output_dict["loss"] = loss

        return output_dict
Example #18
def _save_ply(
    f,
    verts: torch.Tensor,
    faces: torch.LongTensor,
    verts_normals: torch.Tensor,
    ascii: bool,
    decimal_places: Optional[int] = None,
) -> None:
    """
    Internal implementation for saving 3D data to a .ply file.

    Args:
        f: File object to which the 3D data should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
        ascii: (bool) whether to use the ascii ply format.
        decimal_places: Number of decimal places for saving if ascii=True.
    """
    assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
    assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
    assert not len(verts_normals) or (verts_normals.dim() == 2
                                      and verts_normals.size(1) == 3)

    if ascii:
        f.write(b"ply\nformat ascii 1.0\n")
    elif sys.byteorder == "big":
        f.write(b"ply\nformat binary_big_endian 1.0\n")
    else:
        f.write(b"ply\nformat binary_little_endian 1.0\n")
    f.write(f"element vertex {verts.shape[0]}\n".encode("ascii"))
    f.write(b"property float x\n")
    f.write(b"property float y\n")
    f.write(b"property float z\n")
    if verts_normals.numel() > 0:
        f.write(b"property float nx\n")
        f.write(b"property float ny\n")
        f.write(b"property float nz\n")
    f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
    f.write(b"property list uchar int vertex_index\n")
    f.write(b"end_header\n")

    if not (len(verts) or len(faces)):
        warnings.warn("Empty 'verts' and 'faces' arguments provided")
        return

    vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
    if ascii:
        if decimal_places is None:
            float_str = "%f"
        else:
            float_str = "%" + ".%df" % decimal_places
        np.savetxt(f, vert_data, float_str)
    else:
        assert vert_data.dtype == np.float32
        if isinstance(f, BytesIO):
            # tofile only works with real files, but is faster than this.
            f.write(vert_data.tobytes())
        else:
            vert_data.tofile(f)

    faces_array = faces.detach().numpy()

    _check_faces_indices(faces, max_index=verts.shape[0])

    if len(faces_array):
        if ascii:
            np.savetxt(f, faces_array, "3 %d %d %d")
        else:
            # rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
            faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
            faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
            if isinstance(f, BytesIO):
                f.write(faces_uints.tobytes())
            else:
                faces_uints.tofile(f)
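A hedged usage sketch of the internal helper above (in PyTorch3D the public ``save_ply`` wrapper is the usual entry point). The file is opened in binary mode because the header is written as bytes, and vertex normals are passed explicitly because the helper concatenates them with the vertices:

verts = torch.rand(4, 3)
faces = torch.LongTensor([[0, 1, 2], [0, 2, 3]])
verts_normals = torch.zeros(4, 3)
with open("mesh.ply", "wb") as f:
    _save_ply(f, verts, faces, verts_normals, ascii=True, decimal_places=5)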
Example #19
def sample_output(
        model: transformer.Transformer,
        input_seq: torch.LongTensor,
        eos_index: int,
        pad_index: int,
        max_len: int
) -> torch.LongTensor:
    """Samples an output sequence based on the provided input.
    
    Args:
        model (:class:`transformer.Transformer`): The model to use.
        input_seq (torch.LongTensor): The input sequence to be provided to the model. This has to be a
            (batch-size x input-seq-len)-tensor.
        eos_index (int): The index that indicates the end of a sequence.
        pad_index (int): The index that indicates a padding token in a sequence.
        max_len (int): The maximum length of the generated output.
    
    Returns:
        torch.LongTensor: The generated output sequence as (batch-size x output-seq-len)-tensor.
    """
    # sanitize args
    if not isinstance(model, transformer.Transformer):
        raise TypeError("The <model> has to be a transformer.Transformer!")
    if not isinstance(input_seq, torch.LongTensor) and not isinstance(input_seq, torch.cuda.LongTensor):
        raise TypeError("The <input_seq> has to be a LongTensor!")
    if input_seq.dim() != 2:
        raise ValueError("<input_seq> has to be a matrix!")
    if not isinstance(eos_index, int):
        raise TypeError("The <eos_index> has to be an integer!")
    if eos_index < 0 or eos_index >= model.output_size:
        raise ValueError("The <eos_index> is not a legal index in the vocabulary used by <model>!")
    if not isinstance(pad_index, int):
        raise TypeError("The <pad_index> has to be an integer!")
    if pad_index < 0 or pad_index >= model.output_size:
        raise ValueError("The <pad_index> is not a legal index in the vocabulary used by <model>!")
    if max_len is not None:
        if not isinstance(max_len, int):
            raise TypeError("<max_len> has to be an integer!")
        if max_len < 1:
            raise ValueError("<max_len> has to be > 0!")
    
    original_mode = model.training  # the original mode (train/eval) of the provided model
    batch_size = input_seq.size(0)  # number of samples in the provided input sequence
    
    # put model in evaluation mode
    model.eval()
    
    output_seq = []  # used to store the generated outputs for each position
    finished = [False] * batch_size
    
    for _ in range(max_len):
        
        # prepare the target to provide to the model
        # this is the current output with an additional final entry that is supposed to be predicted next
        # (which is why the concrete value does not matter)
        current_target = torch.cat(output_seq + [input_seq.new(batch_size, 1).zero_()], dim=1)
        
        # run the model
        probs = model(input_seq, current_target)[:, -1, :]
        
        # sample the next output from the computed probabilities
        output = torch.multinomial(probs, 1)
        
        # determine which samples are finished, and replace the sampled output with padding for those that already are
        for sample_idx in range(batch_size):
            if finished[sample_idx]:
                output[sample_idx, 0] = pad_index
            elif output[sample_idx, 0].item() == eos_index:
                finished[sample_idx] = True
        
        # store created output
        output_seq.append(output)
        
        # check whether generation has been finished
        if all(finished):
            break
    
    # restore original mode of the model
    model.train(mode=original_mode)
    
    return torch.cat(output_seq, dim=1)
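A hedged usage sketch for sample_output, assuming a trained ``transformer.Transformer`` whose vocabulary uses index 1 for EOS and index 0 for padding (illustrative indices, not fixed by the snippet):

input_seq = torch.LongTensor([[12, 7, 9, 3], [5, 5, 8, 2]])  # (batch-size x input-seq-len)
output_seq = sample_output(model, input_seq, eos_index=1, pad_index=0, max_len=20)
print(output_seq.shape)  # (batch-size x generated-seq-len)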
Example #20
def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor,
                       other_tokens: List[str]):
    """
    Helper to align GPT-2 BPE to other tokenization formats (e.g., spaCy).
    Args:
        roberta (RobertaHubInterface): RoBERTa instance
        bpe_tokens (torch.LongTensor): GPT-2 BPE tokens of shape `(T_bpe)`
        other_tokens (List[str]): other tokens of shape `(T_words)`
    Returns:
        List[List[int]]: for each of *other_tokens*, the positions of its corresponding *bpe_tokens*.
    """
    assert bpe_tokens.dim() == 1
    assert bpe_tokens[0] == 0

    def clean(text):
        return text.strip()

    # remove whitespaces to simplify alignment
    bpe_tokens = [
        roberta.task.source_dictionary.string([x]) for x in bpe_tokens
    ]
    bpe_tokens = [
        clean(roberta.bpe.decode(x) if x not in {'<s>', ''} else x)
        for x in bpe_tokens
    ]
    other_tokens = [clean(str(o)) for o in other_tokens]

    # strip leading <s>
    bpe_tokens = bpe_tokens[1:]
    # assert ''.join(bpe_tokens) == ''.join(other_tokens)

    # create alignment from every word to a list of BPE tokens
    alignment = []
    # print(bpe_tokens)
    # print(other_tokens, '\n')
    bpe_toks = filter(lambda item: item[1] != '', enumerate(bpe_tokens,
                                                            start=1))
    j, bpe_tok = next(bpe_toks)
    for other_tok in other_tokens:
        if other_tok == '':
            print("empty")
        bpe_indices = []
        while True:
            if bpe_tok == '<unk>':
                unk_tok = roberta.bpe.encode(other_tok).split()[0].replace(
                    '@@', '')
                other_tok = other_tok[len(unk_tok):]
                try:
                    j, bpe_tok = next(bpe_toks)
                except StopIteration:
                    j, bpe_tok = None, None
            if other_tok.startswith(bpe_tok):
                bpe_indices.append(j)
                other_tok = other_tok[len(bpe_tok):]
                try:
                    j, bpe_tok = next(bpe_toks)
                    # break
                except StopIteration:
                    j, bpe_tok = None, None
            elif bpe_tok.startswith(other_tok):
                # other_tok spans multiple BPE tokens
                bpe_indices.append(j)
                bpe_tok = bpe_tok[len(other_tok):]
                other_tok = ''
            else:
                raise Exception('Cannot align "{}" and "{}"'.format(
                    other_tok, bpe_tok))
            if other_tok == '':
                break
        assert len(bpe_indices) > 0
        alignment.append(bpe_indices)
    assert len(alignment) == len(other_tokens)

    return alignment
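
A hedged usage sketch for the helper above; the checkpoint paths are placeholders and the word-level tokenization is illustrative only:

import torch
from fairseq.models.roberta import RobertaModel

# load a pretrained RoBERTa hub interface (paths are hypothetical)
roberta = RobertaModel.from_pretrained('checkpoints/roberta.base', checkpoint_file='model.pt')
roberta.eval()

bpe_tokens = roberta.encode('Hello world!')   # LongTensor starting with <s>
other_tokens = ['Hello', 'world', '!']        # e.g. a spaCy tokenization

alignment = align_bpe_to_words(roberta, bpe_tokens, other_tokens)
# alignment[i] lists the positions of the BPE tokens that make up other_tokens[i]
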
Beispiel #21
0
    def forward(
            self,  # type: ignore
            question_passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            mask_indices: torch.LongTensor,
            num_spans: torch.LongTensor = None,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_expressions: torch.LongTensor = None,
            answer_as_expressions_extra: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Shape: (batch_size, seqlen)
        question_passage_tokens = question_passage["tokens"]
        # Shape: (batch_size, seqlen)
        pad_mask = question_passage["mask"]
        # Shape: (batch_size, seqlen)
        seqlen_ids = question_passage["tokens-type-ids"]

        max_seqlen = question_passage_tokens.shape[-1]
        batch_size = question_passage_tokens.shape[0]

        # Shape: (batch_size, 3)
        mask = mask_indices.squeeze(-1)
        # Shape: (batch_size, seqlen)
        cls_sep_mask = \
            torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(
                1, mask, torch.zeros(mask.shape, device=mask.device).long())
        # Shape: (batch_size, seqlen)
        passage_mask = seqlen_ids * pad_mask * cls_sep_mask
        # Shape: (batch_size, seqlen)
        question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask

        tags = self.ner_tagger(question_passage_tokens)

        # Shape: (batch_size, seqlen, bert_dim)
        bert_out, _ = self.BERT(question_passage_tokens,
                                seqlen_ids,
                                pad_mask,
                                output_all_encoded_layers=False)
        # Shape: (batch_size, qlen, bert_dim)
        question_end = max(mask[:, 1])
        question_out = bert_out[:, :question_end]
        # Shape: (batch_size, qlen)
        question_mask = question_mask[:, :question_end]
        # Shape: (batch_size, out)
        question_vector = self.summary_vector(question_out, question_mask,
                                              "question")

        passage_out = bert_out
        del bert_out

        # Shape: (batch_size, bert_dim)
        passage_vector = self.summary_vector(passage_out, passage_mask)

        if "arithmetic" in self.answering_abilities and self.arithmetic == "advanced":
            arithmetic_summary = self.summary_vector(passage_out, pad_mask,
                                                     "arithmetic")
            #             arithmetic_summary = self.summary_vector(question_out, question_mask, "arithmetic")

            # Shape: (batch_size, # of numbers in the passage)
            if number_indices.dim() == 3:
                number_indices = number_indices[:, :, 0].long()
            number_mask = (number_indices != -1).long()
            clamped_number_indices = util.replace_masked_values(
                number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                passage_out, 1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, passage_out.size(-1)))
            op_mask = torch.ones((batch_size, self.num_ops + 1),
                                 device=number_mask.device).long()
            options_mask = torch.cat([op_mask, number_mask], -1)
            ops = self.op_embeddings(
                torch.arange(self.num_ops + 1,
                             device=number_mask.device).expand(batch_size, -1))
            options = torch.cat([self.Wo(ops), self.Wc(encoded_numbers)], 1)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            count_number_log_probs, best_count_number = self._count_module(
                passage_vector)

        if "passage_span_extraction" in self.answering_abilities:
            passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \
                self._passage_span_module(passage_out, passage_mask)

        if "question_span_extraction" in self.answering_abilities:
            question_span_start_log_probs, question_span_end_log_probs, best_question_span = \
                self._question_span_module(passage_vector, question_out, question_mask)

        if "arithmetic" in self.answering_abilities:
            if self.arithmetic == "base":
                number_mask = (number_indices[:, :, 0].long() != -1).long()
                number_sign_log_probs, best_signs_for_numbers, number_mask = \
                    self._base_arithmetic_module(passage_vector, passage_out, number_indices, number_mask)
            else:
                arithmetic_logits, best_expression = \
                    self._adv_arithmetic_module(arithmetic_summary, self.max_explen, options, options_mask, \
                                                   passage_out, pad_mask)
                shapes = arithmetic_logits.shape
                # guard against NaNs in the logits; fall back to random logits if any are found
                if torch.isnan(arithmetic_logits).any():
                    print("bad logits")
                    arithmetic_logits = torch.rand(
                        shapes,
                        device=arithmetic_logits.device,
                        requires_grad=True)

        output_dict = {}
        del passage_out, question_out
        # If answer is given, compute the loss.
        if answer_as_passage_spans is not None or answer_as_question_spans is not None \
                or answer_as_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    log_marginal_likelihood_for_passage_span = \
                        self._passage_span_log_likelihood(answer_as_passage_spans,
                                                          passage_span_start_log_probs,
                                                          passage_span_end_log_probs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    log_marginal_likelihood_for_question_span = \
                        self._question_span_log_likelihood(answer_as_question_spans,
                                                           question_span_start_log_probs,
                                                           question_span_end_log_probs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_question_span)

                elif answering_ability == "arithmetic":
                    if self.arithmetic == "base":
                        log_marginal_likelihood_for_arithmetic = \
                            self._base_arithmetic_log_likelihood(answer_as_expressions,
                                                                 number_sign_log_probs,
                                                                 number_mask,
                                                                 answer_as_expressions_extra)
                    else:
                        max_explen = answer_as_expressions.shape[-1]
                        possible_exps = answer_as_expressions.shape[1]
                        limit = min(possible_exps, 1000)
                        log_marginal_likelihood_for_arithmetic = \
                            self._adv_arithmetic_log_likelihood(arithmetic_logits[:,:max_explen,:],
                                                                answer_as_expressions[:,:limit,:].long())
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_arithmetic)

                elif answering_ability == "counting":
                    log_marginal_likelihood_for_count = \
                        self._count_log_likelihood(answer_as_counts,
                                                   count_number_log_probs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count)

                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = -marginal_log_likelihood.mean()
        with torch.no_grad():
            # Compute the metrics and add the tokenized input to the output.
            if metadata is not None:
                output_dict["question_id"] = []
                output_dict["answer"] = []
                question_tokens = []
                passage_tokens = []
                for i in range(batch_size):
                    if len(self.answering_abilities) > 1:
                        predicted_ability_str = self.answering_abilities[
                            best_answer_ability[i]]
                    else:
                        predicted_ability_str = self.answering_abilities[0]
                    answer_json: Dict[str, Any] = {}

                    # We did not consider multi-mention answers here
                    if predicted_ability_str == "passage_span_extraction":
                        answer_json["answer_type"] = "passage_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(question_passage_tokens[i], best_passage_span[i])
                    elif predicted_ability_str == "question_span_extraction":
                        answer_json["answer_type"] = "question_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(question_passage_tokens[i], best_question_span[i])
                    elif predicted_ability_str == "arithmetic":  # plus_minus combination answer
                        answer_json["answer_type"] = "arithmetic"
                        original_numbers = metadata[i]['original_numbers']
                        if self.arithmetic == "base":
                            answer_json["value"], answer_json["numbers"] = \
                                self._base_arithmetic_prediction(original_numbers, number_indices[i], best_signs_for_numbers[i])
                        else:
                            answer_json["value"], answer_json["expression"] = \
                                self._adv_arithmetic_prediction(original_numbers, best_expression[i])
                    elif predicted_ability_str == "counting":
                        answer_json["answer_type"] = "count"
                        answer_json["value"], answer_json["count"] = \
                            self._count_prediction(best_count_number[i])
                    else:
                        raise ValueError(
                            f"Unsupported answer ability: {predicted_ability_str}"
                        )

                    output_dict["question_id"].append(
                        metadata[i]["question_id"])
                    output_dict["answer"].append(answer_json)
                    answer_annotations = metadata[i].get(
                        'answer_annotations', [])
                    if answer_annotations:
                        self._drop_metrics(answer_json["value"],
                                           answer_annotations)

        return output_dict
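
The per-ability log-likelihoods above are merged with the ability log-probabilities by a plain logsumexp; a standalone sketch of that marginalisation with toy numbers:

import torch

# toy values: 2 batch items, 3 answering abilities
log_likelihoods = torch.log(torch.tensor([[0.2, 0.5, 0.1],
                                          [0.4, 0.3, 0.2]]))
ability_log_probs = torch.log_softmax(torch.randn(2, 3), dim=-1)

# p(answer) = sum_k p(ability=k) * p(answer | ability=k), computed in log space
marginal_log_likelihood = torch.logsumexp(log_likelihoods + ability_log_probs, dim=-1)
loss = -marginal_log_likelihood.mean()
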
Beispiel #22
0
    def forward(self, batch: torch.LongTensor) -> torch.FloatTensor:
        """Computes the loss function.

        Args:
            batch (torch.LongTensor): A batch of training data, as (batch-size x max-seq-len)-tensor.

        Returns:
            torch.FloatTensor: The computed loss.
        """
        # sanitize args
        insanity.sanitize_type("batch", batch, torch.Tensor)
        if batch.dtype != torch.int64:
            raise TypeError("<batch> has to be a LongTensor!")
        if batch.dim() != 2:
            raise ValueError("<batch> has to be a 2d tensor!")
        
        # create the padding mask to use
        padding_mask = util.create_padding_mask(batch, self._pad_index)
        
        # create a tensor of indices, which is used to retrieve the according positional embeddings below
        index_seq = batch.new(range(batch.size(1))).unsqueeze(0).expand(batch.size(0), -1)
        
        # compute the sequence lengths for all samples in the batch
        seq_len = (batch != self._pad_index).sum(dim=1).cpu().numpy().tolist()
        
        # randomly choose the tokens to compute predictions for
        pred_mask = padding_mask.new(*batch.size()).zero_().long()  # all tokens being predicted
        mask_mask = padding_mask.new(*batch.size()).zero_().long()  # tokens replaced with <MASK>
        random_mask = padding_mask.new(*batch.size()).zero_().long()  # tokens replaced with random tokens
        for sample_idx, sample_len in enumerate(seq_len):  # iterate over all samples in the batch
            
            # determine how many tokens to compute predictions for
            num_pred = int(math.ceil(sample_len * self._prediction_rate))  # num of tokens predictions are computed for
            num_mask = int(math.floor(num_pred * self._mask_rate))  # num of tokens replaced with <MASK>
            num_random = int(math.ceil(num_pred * self._random_rate))  # num of tokens randomly replaced
            
            # randomly select indices to compute predictions for
            pred_indices = list(range(sample_len))
            random.shuffle(pred_indices)
            pred_indices = pred_indices[:num_pred]
            
            # prepare the <MASK>-mask
            for token_idx in pred_indices[:num_mask]:
                pred_mask[sample_idx, token_idx] = 1
                mask_mask[sample_idx, token_idx] = 1
            
            # prepare the random-mask
            for token_idx in pred_indices[num_mask:(num_mask + num_random)]:
                pred_mask[sample_idx, token_idx] = 1
                random_mask[sample_idx, token_idx] = 1
            
            # remaining tokens that predictions are computed for are left untouched
            for token_idx in pred_indices[(num_mask + num_random):]:
                pred_mask[sample_idx, token_idx] = 1
        
        # replace predicted tokens in the batch appropriately
        masked_batch = (
                batch * (1 - mask_mask) * (1 - random_mask) +
                mask_mask * batch.new(*batch.size()).fill_(self._mask_index) +
                random_mask * (batch.new(*batch.size()).double().uniform_() * self._word_emb.num_embeddings).long()
        )
        
        # embed the batch
        masked_batch = self._word_emb(masked_batch) + self._pos_emb(index_seq)
        
        # encode sequence in the batch using BERT
        enc = self._model(masked_batch, padding_mask)
        
        # turn encodings, the target token indices (that we seek to predict), and the prediction mask, into matrices,
        # such that each row corresponds with one token
        enc = enc.view(enc.size(0) * enc.size(1), enc.size(2))
        target = batch.view(-1)
        pred_mask = pred_mask.view(-1)
        
        # turn the prediction mask into a tensor of indices (to select below)
        pred_mask = pred_mask.new(np.where(pred_mask.detach().cpu().numpy())[0])
        
        # fetch embeddings and target values of those tokens that are being predicted
        enc = enc.index_select(0, pred_mask)
        target = target.index_select(0, pred_mask)
        
        # compute predictions for each encoded token + the according loss
        pred = self._output_layer(enc)
        loss = self._loss(pred, target)
        
        return loss
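
A small sketch of how the prediction/mask/random counts above partition one sample, using the common BERT-style rates (15% predicted, 80% of those masked, 10% randomly replaced); the rates are illustrative, not taken from the example's configuration:

import math
import random

sample_len = 20
prediction_rate, mask_rate, random_rate = 0.15, 0.8, 0.1

num_pred = int(math.ceil(sample_len * prediction_rate))   # tokens predictions are computed for
num_mask = int(math.floor(num_pred * mask_rate))          # of those, replaced with <MASK>
num_random = int(math.ceil(num_pred * random_rate))       # of those, replaced with random tokens

pred_indices = list(range(sample_len))
random.shuffle(pred_indices)
pred_indices = pred_indices[:num_pred]

masked = pred_indices[:num_mask]
randomized = pred_indices[num_mask:num_mask + num_random]
untouched = pred_indices[num_mask + num_random:]          # predicted but left unchanged
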
Beispiel #23
0
    def __call__(self, prediction_labels: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: torch.LongTensor) -> Dict[str, float]:
        """
        计算 metric.

        :param prediction_labels: 预测结果 shape: (B,), 这是模型解码成label的结果,注意是 label,不是 logits
        :param gold_labels: 实际结果 shape: (B,)
        :param mask: mask, shape: (B,)
        :return 每一个label的f1值。这包括 precision, recall, f1。具体结果类似:

        {"precision_[label]": [value],
         "recall_[label]" : [value],
         "f1-measure_[label]": [value],
         "precision-overall": [value],
         "recall-overall": [value],
         "f1-measure-overall": [value]}

         说明: "*-overall" 表示的是所有 命中 在初始化参数的 labels 的综合 metric 值。
         这是有必要的,作为一个综合的值作为统一衡量。
        """

        assert prediction_labels.dim() == 1, "predictions shape should be (B,)"
        assert gold_labels.dim() == 1, "gold_labels shape should be (B,)"

        if mask is not None:
            assert mask.dim() == 1, "mask shape should be (B,)"

        # move to cpu for the computation
        prediction_labels, gold_labels = prediction_labels.detach().cpu(
        ), gold_labels.detach().cpu()

        if mask is not None:
            bool_mask = mask.detach().cpu().bool()

            prediction_labels = prediction_labels.masked_select(bool_mask)
            gold_labels = gold_labels.masked_select(bool_mask)

        # true_positives for the current batch
        true_positives = defaultdict(int)
        false_positives = defaultdict(int)
        false_negatives = defaultdict(int)

        for label, label_index in zip(self._labels, self._label_indices):
            num_prediction = (prediction_labels == label_index).sum().item()
            num_golden = (gold_labels == label_index).sum().item()

            # compute true positives
            label_mask = (prediction_labels == label_index)
            label_predictions = prediction_labels.masked_select(label_mask)
            label_gold = gold_labels.masked_select(label_mask)

            true_positives[label] = (
                label_predictions == label_gold).sum().item()
            false_positives[label] = num_prediction - true_positives[label]
            false_negatives[label] = num_golden - true_positives[label]

        for k, v in true_positives.items():
            self._true_positives[k] += v

        for k, v in false_positives.items():
            self._false_positives[k] += v

        for k, v in false_negatives.items():
            self._false_negatives[k] += v

        return self._metric(true_positives=true_positives,
                            false_positives=false_positives,
                            false_negatives=false_negatives)
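
For reference, precision, recall and F1 follow directly from the accumulated counts; a minimal sketch of the standard formulas (the helper below is not the class's `_metric`, just an illustration):

def f1_from_counts(true_positives: int, false_positives: int, false_negatives: int):
    precision = true_positives / (true_positives + false_positives + 1e-13)
    recall = true_positives / (true_positives + false_negatives + 1e-13)
    f1 = 2.0 * precision * recall / (precision + recall + 1e-13)
    return {"precision": precision, "recall": recall, "f1-measure": f1}

# e.g. 8 correct predictions, 2 spurious predictions, 4 missed gold labels
print(f1_from_counts(8, 2, 4))
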
    def forward(self, input_seq: torch.LongTensor,
                target: torch.LongTensor) -> torch.FloatTensor:
        """Runs the Transformer.
        
        The Transformer expects both an input as well as a target sequence to be provided, and yields a probability
        distribution over all possible output tokens for each position in the target sequence.
        
        Args:
            input_seq (torch.LongTensor): The input sequence as (batch-size x input-seq-len)-tensor.
            target (torch.LongTensor): The target sequence as (batch-size x target-seq-len)-tensor.
        
        Returns:
            torch.FloatTensor: The computed probabilities for each position in ``target`` as a
                (batch-size x target-seq-len x output-size)-tensor.
        """
        # sanitize args
        if not isinstance(input_seq, torch.LongTensor) and not isinstance(
                input_seq, torch.cuda.LongTensor):
            raise TypeError("<input_seq> has to be a LongTensor!")
        if input_seq.dim() != 2:
            raise ValueError("<input_seq> has to have 2 dimensions!")
        if not isinstance(target, torch.LongTensor) and not isinstance(
                target, torch.cuda.LongTensor):
            raise TypeError("<target> has to be a LongTensor!")
        if target.dim() != 2:
            raise ValueError("<target> has to have 2 dimensions!")

        # create a tensor of indices, which is used to retrieve the according positional embeddings below
        index_seq = input_seq.new(range(
            input_seq.size(1))).unsqueeze(0).expand(input_seq.size(0), -1)

        # create padding mask for input
        padding_mask = util.create_padding_mask(input_seq, self._pad_index)

        # embed the provided input
        input_seq = self._word_emb(input_seq) + self._positional_emb(index_seq)

        # project input to the needed size
        input_seq = self._input_projection(input_seq)

        # run the encoder
        input_seq = self._encoder(input_seq, padding_mask=padding_mask)

        # create a tensor of indices, which is used to retrieve the positional embeddings for the targets below
        index_seq = target.new(range(target.size(1))).unsqueeze(0).expand(
            target.size(0), -1)

        # embed the provided targets
        target = self._word_emb(target) + self._positional_emb(index_seq)

        # project target to the needed size
        target = self._input_projection(target)

        # run the decoder
        output = self._decoder(input_seq, target, padding_mask=padding_mask)

        # project output to the needed size
        output = self._output_projection(output)

        # compute softmax
        return functional.softmax(output, dim=2)
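
The `input_seq.new(range(...))` pattern above builds the per-position index tensor for the positional embeddings; a quick sketch of the same construction with `torch.arange`, which is the more common idiom:

import torch

batch_size, seq_len = 4, 7
input_seq = torch.randint(0, 100, (batch_size, seq_len))

# one row of position indices, broadcast to every sample in the batch
index_seq = torch.arange(seq_len, device=input_seq.device).unsqueeze(0).expand(batch_size, -1)
assert index_seq.shape == (batch_size, seq_len)
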
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
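
A short sketch of the span-mask logic used above: padded spans carry a negative start index, so comparing the start column against zero yields the mask (the toy spans below are made up):

import torch

# two sentences, three candidate spans each; the last two spans of sentence 2 are padding
spans = torch.tensor([[[0, 1], [1, 2], [0, 2]],
                      [[0, 0], [-1, -1], [-1, -1]]])

span_mask = (spans[:, :, 0] >= 0).long()   # (batch_size, num_spans)
num_spans = span_mask.sum(dim=-1)          # number of non-padded spans per sentence
print(span_mask)   # tensor([[1, 1, 1], [1, 0, 0]])
print(num_spans)   # tensor([3, 1])
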
Beispiel #26
0
    def _compute_score(self, emissions: torch.Tensor, tags1: torch.LongTensor,
                       tags2: torch.LongTensor,
                       sims: torch.LongTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags1.dim() == 2 and tags2.dim(
        ) == 2 and sims.dim() == 2
        assert emissions.shape[:2] == tags1.shape == tags2.shape == sims.shape
        assert emissions.size(2) == self.num_tags

        seq_length, batch_size = tags1.shape

        # Start transition score and first emission
        # shape: (batch_size,)
        if sims[0]:  # current tags are the same
            score = self.start_transitions[tags1[0]]
            score = score + emissions[0, torch.arange(batch_size), tags1[0]]
        elif sims[1]:  # current tags differ, next word's tags are the same
            score1 = self.start_transitions[tags1[0]]
            score2 = self.start_transitions[tags2[0]]
            score1 = score1 + emissions[0, torch.arange(batch_size), tags1[0]]
            score2 = score2 + emissions[0, torch.arange(batch_size), tags2[0]]
            score = torch.logsumexp(torch.stack((score1, score2), 1), dim=1)
        else:  # current tags differ, next word's tags differ as well
            score1 = self.start_transitions[tags1[0]]
            score2 = self.start_transitions[tags2[0]]
            score1 = score1 + emissions[0, torch.arange(batch_size), tags1[0]]
            score2 = score2 + emissions[0, torch.arange(batch_size), tags2[0]]

        for i in range(1, seq_length - 1):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            if sims[i]:  # current tags are the same
                score = score + self.transitions[tags1[i - 1], tags1[i]]
                score = score + emissions[i,
                                          torch.arange(batch_size), tags1[i]]
            elif sims[i - 1] and sims[i + 1]:  # previous word's tags same, current tags differ, next word's tags same
                score1 = score + self.transitions[tags1[i - 1], tags1[i]]
                score1 = score1 + emissions[i,
                                            torch.arange(batch_size), tags1[i]]
                score2 = score + self.transitions[tags2[i - 1], tags2[i]]
                score2 = score2 + emissions[i,
                                            torch.arange(batch_size), tags2[i]]
                score = torch.logsumexp(torch.stack((score1, score2), 1),
                                        dim=1)
            elif sims[i - 1]:  # previous word's tags same, current tags differ, next word's tags differ
                score1 = score + self.transitions[tags1[i - 1], tags1[i]]
                score1 = score1 + emissions[i,
                                            torch.arange(batch_size), tags1[i]]
                score2 = score + self.transitions[tags2[i - 1], tags2[i]]
                score2 = score2 + emissions[i,
                                            torch.arange(batch_size), tags2[i]]
            elif sims[i + 1]:  # previous word's tags differ, current tags differ, next word's tags same
                score1 = score1 + self.transitions[
                    tags1[i - 1], tags1[i]] + self.transitions[tags2[i - 1],
                                                               tags1[i]]
                score1 = score1 + emissions[i,
                                            torch.arange(batch_size), tags1[i]]
                score2 = score2 + self.transitions[
                    tags2[i - 1], tags2[i]] + self.transitions[tags1[i - 1],
                                                               tags2[i]]
                score2 = score2 + emissions[i,
                                            torch.arange(batch_size), tags2[i]]
                score = torch.logsumexp(torch.stack((score1, score2), 1),
                                        dim=1)
            else:  # previous word's tags differ, current tags differ, next word's tags differ as well
                score1 = score1 + self.transitions[
                    tags1[i - 1], tags1[i]] + self.transitions[tags2[i - 1],
                                                               tags1[i]]
                score1 = score1 + emissions[i,
                                            torch.arange(batch_size), tags1[i]]
                score2 = score2 + self.transitions[
                    tags2[i - 1], tags2[i]] + self.transitions[tags1[i - 1],
                                                               tags2[i]]
                score2 = score2 + emissions[i,
                                            torch.arange(batch_size), tags2[i]]

        # End transition score
        # shape: (batch_size,)
        seq_ends = seq_length - 1
        if sims[seq_ends]:  # current tags are the same
            score = score + self.transitions[tags1[seq_ends - 1],
                                             tags1[seq_ends]]
            score = score + emissions[seq_ends,
                                      torch.arange(batch_size),
                                      tags1[seq_ends]]
            # shape: (batch_size,)
            last_tags = tags1[seq_ends, torch.arange(batch_size)]
            # shape: (batch_size,)
            score = score + self.end_transitions[last_tags]
        elif sims[seq_ends - 1]:  # previous word's tags same, current tags differ
            score1 = score + self.transitions[tags1[seq_ends - 1],
                                              tags1[seq_ends]]
            score1 = score1 + emissions[seq_ends,
                                        torch.arange(batch_size),
                                        tags1[seq_ends]]
            last_tags1 = tags1[seq_ends, torch.arange(batch_size)]
            score1 = score1 + self.end_transitions[last_tags1]

            score2 = score + self.transitions[tags2[seq_ends - 1],
                                              tags2[seq_ends]]
            score2 = score2 + emissions[seq_ends,
                                        torch.arange(batch_size),
                                        tags2[seq_ends]]
            last_tags2 = tags2[seq_ends, torch.arange(batch_size)]
            score2 = score2 + self.end_transitions[last_tags2]
            score = torch.logsumexp(torch.stack((score1, score2), 1), dim=1)
        else:  # previous word's tags differ, current tags differ
            score1 = score1 + self.transitions[
                tags1[seq_ends - 1],
                tags1[seq_ends]] + self.transitions[tags2[seq_ends - 1],
                                                    tags1[seq_ends]]
            score1 = score1 + emissions[seq_ends,
                                        torch.arange(batch_size),
                                        tags1[seq_ends]]
            last_tags1 = tags1[seq_ends, torch.arange(batch_size)]
            score1 = score1 + self.end_transitions[last_tags1]

            score2 = score2 + self.transitions[
                tags2[seq_ends - 1],
                tags2[seq_ends]] + self.transitions[tags1[seq_ends - 1],
                                                    tags2[seq_ends]]
            score2 = score2 + emissions[seq_ends,
                                        torch.arange(batch_size),
                                        tags2[seq_ends]]
            last_tags2 = tags2[seq_ends, torch.arange(batch_size)]
            score2 = score2 + self.end_transitions[last_tags2]
            score = torch.logsumexp(torch.stack((score1, score2), 1), dim=1)
        return score
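
Wherever the two candidate tag sequences disagree, the code above keeps both partial scores and merges them with a logsumexp; the merge in isolation, with toy per-sample scores:

import torch

# per-sample partial scores for the two alternative tag sequences
score1 = torch.tensor([1.2, -0.3])
score2 = torch.tensor([0.8, 0.1])

# log(exp(score1) + exp(score2)), computed stably per sample
score = torch.logsumexp(torch.stack((score1, score2), 1), dim=1)
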
Beispiel #27
0
    def compute_partial_decoded_loss(
        self,
        batch: Batch,
        latent: torch.Tensor,
        encoder_states: Tuple[torch.Tensor, ...],
        cand_vecs: torch.LongTensor,
        label_inds: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Compute partial loss from decoding outputs.

        Here, we consider each partially decoded sequence as a separate
        item from which to compute multiobjective scores.

        :param batch:
            batch being considered
        :param latent:
            decoder output representations
        :param encoder_states:
            encoder output representations
        :param cand_vecs:
            character candidate vectors
        :param label_inds:
            list of indices indicating which character is correct in the character candidates

        :return partial_loss:
            return loss for each batch item as a sum of the partial losses.
        """
        assert self.opt['multiobjective_latent_representation'] == 'decoder_final_layer'
        assert latent.dim() == 3 and latent.size(0) == cand_vecs.size(0)
        bsz, seq_len, dim = latent.size()
        seq_lens = []
        partial_char_losses = []
        seq_scores = []
        stride_length = 2
        for stride in range(0, bsz, stride_length):  # arbitrary stride for now
            # Compute new batches of items; latent reps, candidate vectors, etc.
            end_idx = min(stride + stride_length, bsz)
            new_bsz = batch.label_vec[stride:end_idx].ne(self.NULL_IDX).sum().item()
            new_latent = latent.new(new_bsz, seq_len, dim).fill_(0)
            new_cand_vecs = cand_vecs.new(new_bsz, *cand_vecs.shape[1:]).fill_(
                self.NULL_IDX
            )
            if new_cand_vecs.dim() == 2:
                new_cand_vecs = new_cand_vecs.unsqueeze(1).repeat(
                    1, cand_vecs.size(0), 1
                )
            new_label_inds = label_inds[stride:end_idx].new(new_bsz).fill_(0)

            # For each batch item in the stride, we compute seq_length examples
            # where each example represents a partial output of the decoder.
            offset = 0
            for i in range(stride, end_idx):
                cand_vecs_i = cand_vecs if cand_vecs.dim() == 2 else cand_vecs[i]
                seq_len_i = batch.label_vec[i].ne(self.NULL_IDX).sum().item()
                seq_lens.append(seq_len_i)
                for j in range(seq_len_i):
                    new_latent[offset + j, 0 : j + 1, :] = latent[
                        i : i + 1, 0 : j + 1, :
                    ]
                new_cand_vecs[offset : offset + seq_len_i] = cand_vecs_i
                new_label_inds[offset : offset + seq_len_i] = label_inds[
                    i : i + 1
                ].repeat(seq_len_i)
                offset += seq_len_i

            assert isinstance(new_cand_vecs, torch.LongTensor)
            seq_score = self.get_multiobjective_output(
                new_latent, encoder_states, new_cand_vecs, 'partial'
            )
            partial_char_losses.append(
                self.multiobj_criterion(seq_score, new_label_inds)
            )
            seq_scores.append(seq_score)
        partial_char_loss = torch.cat(partial_char_losses, dim=0)
        seq_scores = torch.cat(seq_scores, dim=0)
        partial_char_loss_metric = partial_char_loss.new(bsz).fill_(0)
        offset = 0
        partial_char_scores = torch.zeros(
            batch.batchsize,
            batch.batchsize if cand_vecs.dim() == 2 else cand_vecs.size(1),
        ).to(latent)
        for i in range(bsz):
            partial_char_loss_metric[i] = partial_char_loss[
                offset : offset + seq_lens[i]
            ].mean()
            partial_char_scores[i] = seq_scores[
                offset + partial_char_loss[offset : offset + seq_lens[i]].argmin()
            ]
            # advance past this item's partial-decoding examples
            offset += seq_lens[i]
        self.compute_multiobj_metrics(
            partial_char_loss_metric, partial_char_scores, label_inds, prefix='partial'
        )
        return partial_char_loss
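
The per-item example counts above come from counting non-padding positions in the label vectors; a toy sketch of that bookkeeping (NULL_IDX below is an assumed padding id):

import torch

NULL_IDX = 0  # assumed padding index
label_vec = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])

# number of partial-decoding examples contributed by each batch item
seq_lens = label_vec.ne(NULL_IDX).sum(dim=-1)   # tensor([3, 2])
new_bsz = label_vec.ne(NULL_IDX).sum().item()   # 5 examples in total for this stride
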
Beispiel #28
0
    def forward(self,
                sentence: LongTensor,
                entity_tag: LongTensor,
                event_type: LongTensor,
                metadata: Dict = None) -> EventModelOutputs:
        """
        模型运行
        :param sentence: shape: (B, SeqLen), 句子的 index tensor
        :param entity_tag: shape: (B, SeqLen), 句子的 实体 index tensor
        :param event_type: shape: (B,), event type 的 tensor
        :param metadata: metadata 数据,不参与模型运算
        """

        assert sentence.dim(
        ) == 2, f"Sentence 的维度 {sentence.dim()} !=2, 应该是(B, seq_len)"
        assert entity_tag.dim(
        ) == 2, f"entity_tag 维度 {entity_tag.dim()} != 2, 应该是 (B, seq_len)"
        assert event_type.dim(
        ) == 1, f"event_type 维度 {event_type.dim()} != 1, 应该是 (B,)"

        batch_size = sentence.size(0)
        seq_len = sentence.size(1)

        # sentence and entity_tag share the same mask
        mask = nn_util.sequence_mask(
            sentence, self._sentence_vocab.index(self._sentence_vocab.padding))
        assert mask.shape == (batch_size, seq_len), "mask shape should be (B, seq_len)"

        # shape: B * SeqLen * sentence_embedding_dim
        sentence_embedding = self._sentence_embedder(sentence)

        assert sentence_embedding.shape == (batch_size, seq_len,
                                            self._sentence_embedding_dim)

        # shape: B * SeqLen * entity_tag_embedding_dim
        entity_tag_embedding = self._entity_tag_embedder(entity_tag)

        assert entity_tag_embedding.shape == (batch_size, seq_len,
                                              self._entity_tag_embedding_dim)

        # shape: B * SeqLen * InputSize, InputSize = sentence_embedding_dim + entity_tag_embedding_dim
        sentence_embedding = torch.cat(
            (sentence_embedding, entity_tag_embedding), dim=-1)
        assert sentence_embedding.shape == (batch_size, seq_len,
                                            self._sentence_embedding_dim +
                                            self._entity_tag_embedding_dim)
        # use the mask to compute the actual sentence lengths, shape: (B,)
        sentence_length = mask.long().sum(dim=-1)

        assert sentence_length.shape == (batch_size, )

        # encode with the LSTM sequence encoder
        packed_sentence_embedding = pack_padded_sequence(
            input=sentence_embedding,
            lengths=sentence_length,
            batch_first=True,
            enforce_sorted=False)

        packed_sequence, (h_n, c_n) = self._lstm(packed_sentence_embedding)

        # tuple of (sentence encoding with shape B * SeqLen * InputSize, sentence lengths)
        (sentence_encoding, _) = pad_packed_sequence(packed_sequence,
                                                     batch_first=True)

        assert sentence_encoding.shape == (batch_size, seq_len,
                                           self._lstm_hidden_size)

        # shape: B * InputSize
        event_type_embedding_1: Tensor = self._event_type_embedder_1(
            event_type)
        assert event_type_embedding_1.shape == (batch_size,
                                                self._event_type_embedding_dim)

        # attention
        # shape: B * InputSize * 1
        event_type_embedding_1_tmp = event_type_embedding_1.unsqueeze(-1)
        assert event_type_embedding_1_tmp.shape == (
            batch_size, self._event_type_embedding_dim, 1)

        # shape: (B * SeqLen * InputSize) bmm (B * InputSize * 1) = B * SeqLen * 1
        attention_logits = sentence_encoding.bmm(event_type_embedding_1_tmp)

        # shape: B * SeqLen
        attention_logits = attention_logits.squeeze(-1)

        assert attention_logits.shape == (batch_size, seq_len)

        # Shape: B * SeqLen
        tmp_attention_logits = torch.exp(attention_logits) * mask.float()

        # Shape: B * 1
        tmp_attention_logits_sum = torch.sum(tmp_attention_logits,
                                             dim=-1,
                                             keepdim=True)

        # Shape: B * SeqLen
        attention = tmp_attention_logits / tmp_attention_logits_sum

        assert attention.shape == (batch_size, seq_len)

        # compute score1, shape: B * 1
        score1 = torch.sum(attention_logits * attention, dim=-1, keepdim=True)

        assert score1.shape == (batch_size, 1)

        score1 = score1.squeeze(dim=-1)

        # global score

        # take the last hidden state, shape: B * INPUT_SIZE
        hidden_last = h_n.squeeze(dim=0)
        assert hidden_last.shape == (batch_size, self._lstm_hidden_size)

        # event type 2, shape: B * INPUT_SIZE
        event_type_embedding_2: Tensor = self._event_type_embedder_2(
            event_type)

        assert event_type_embedding_2.shape == (batch_size,
                                                self._event_type_embedding_dim)

        # shape: B * INPUT_SIZE
        tmp = hidden_last * event_type_embedding_2

        # shape: B * 1
        score2 = torch.sum(tmp, dim=-1, keepdim=True)

        assert score2.shape == (batch_size, 1)

        score2 = score2.squeeze(dim=-1)

        # final score, shape: (B,)
        score = score1 * self._alpha + score2 * (1 - self._alpha)
        assert score.shape == (batch_size, )

        if self._activate_score:  # activate with a sigmoid
            score = torch.sigmoid(score)

        return EventModelOutputs(logits=score, event_type=event_type)
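
The attention weights above are a masked softmax written out by hand (exponentiate, zero out padding, renormalise); a standalone sketch showing it matches the usual masked-softmax formulation:

import torch

attention_logits = torch.tensor([[2.0, 1.0, 0.5, 0.0],
                                 [1.0, 0.2, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

# hand-rolled masked softmax, mirroring the model code
exp_logits = torch.exp(attention_logits) * mask.float()
attention = exp_logits / exp_logits.sum(dim=-1, keepdim=True)

# equivalent: softmax after pushing masked positions to -inf
reference = torch.softmax(attention_logits.masked_fill(mask == 0, float('-inf')), dim=-1)
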
    def forward(
        self,  # type: ignore
        tokens: TextFieldTensors,
        spans: torch.LongTensor,
        metadata: List[Dict[str, Any]],
        pos_tags: TextFieldTensors = None,
        span_labels: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
            tensors.  At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
            Tensor(batch_size, num_tokens)}`. This dictionary will have the same keys as were used
            for the `TokenIndexers` when you created the `TextField` representing your
            sequence.  The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : `List[Dict[str, Any]]`, required.
            A dictionary of metadata for each batch element which has keys:
                tokens : `List[str]`, required.
                    The original string tokens in the sentence.
                gold_tree : `nltk.Tree`, optional (default = `None`)
                    Gold NLTK trees for use in evaluation.
                pos_tags : `List[str]`, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : `torch.LongTensor`, optional (default = `None`)
            The output of a `SequenceLabelField` containing POS tags.
        span_labels : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape `(batch_size, num_spans)`.

        # Returns

        An output dictionary consisting of:

        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_spans, span_label_vocab_size)`
            representing a distribution over the label classes per span.
        spans : `torch.LongTensor`
            The original spans tensor.
        tokens : `List[List[str]]`, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : `List[List[str]]`, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : `torch.LongTensor`, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in `enumerated_spans`.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)

        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)

        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)

        logits = self.tag_projection_layer(span_representations)
        class_probabilities = masked_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
            "class_probabilities": class_probabilities,
            "spans": spans,
            "tokens": [meta["tokens"] for meta in metadata],
            "pos_tags": [meta.get("pos_tags") for meta in metadata],
            "num_spans": num_spans,
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [
                list(zip(*tree.pos()))[1] for tree in batch_gold_trees
            ]
            predicted_trees = self.construct_trees(
                class_probabilities.cpu().data,
                spans.cpu().data,
                num_spans.data,
                output_dict["tokens"],
                gold_pos_tags,
            )
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
Beispiel #30
0
    def forward(self, sentence: LongTensor, category: LongTensor) -> ACSAModelOutputs:
        """
        模型运行
        :param sentence: 句子的 index tensor
        :param category: category index tensor
        :return:
        """

        assert sentence.dim() == 2, f"sentence dim: {sentence.dim()} does not match (batch_size, seq_len)"
        assert category.dim() == 1, f"category dim: {category.dim()} does not match (batch_size,)"

        bool_mask: BoolTensor = sequence_mask(sequence=sentence,
                                              padding_index=self._token_vocabulary.padding_index)
        long_mask = bool_mask.long()

        sentence_length = long_mask.sum(dim=-1)
        assert sentence_length.dim() == 1, f"sentence_length dim: {sentence_length.dim()} does not match (batch_size,)"

        # sentence embedding, shape: (batch_size, seq_len, embedding_dim)
        sentence_embedding = self.token_embedding(sentence)

        assert sentence_embedding.dim() == 3, \
            f"sentence_embedding dim: {sentence_embedding.dim()} does not match (batch_size, seq_len, embedding_dim)"

        # expand category: (batch_size,) -> (batch_size, seq_len)
        # category.unsqueeze: (batch_size,) -> (batch_size, 1)
        category = category.unsqueeze(dim=1)
        # category.expand_as: (batch_size, 1) -> (batch_size, seq_len)
        category = category.expand_as(sentence)

        # category embedding, shape: (batch_size, seq_len, category_embedding_dim)
        category_embedding = self.category_embedding(category)
        assert category_embedding.dim() == 3, \
            f"category_embedding dim: {category_embedding.dim()} does not match (batch_size, seq_len, category_embedding_dim)"

        # concatenate the word embedding and the category embedding
        input_embedding = torch.cat((category_embedding, sentence_embedding), dim=-1)

        # encode with the LSTM sequence encoder
        packed_sentence_embedding = pack_padded_sequence(input=input_embedding,
                                                         lengths=sentence_length,
                                                         batch_first=True,
                                                         enforce_sorted=False)

        packed_sequence, (h_n, c_n) = self.lstm(packed_sentence_embedding)

        # tuple of (sentence encoding with shape B * SeqLen * InputSize, sentence lengths)
        (sentence_encoding, _) = pad_packed_sequence(packed_sequence, batch_first=True)

        # h_n shape (num_layers * num_directions, batch_size, hidden_size)
        h_n = torch.transpose(h_n, 0, 1)

        last_index = -2 if self.lstm.bidirectional else -1
        hidden_size = self.lstm.hidden_size * 2 if self.lstm.bidirectional else self.lstm.hidden_size

        # hn_last shape: (batch_size, hidden_size * (1 or 2))
        hn_last = h_n[:, last_index:, :].contiguous().view(-1, hidden_size)

        # concatenate the LSTM output with the aspect embedding in preparation for attention
        # attention_inputs shape: (batch_size, seq_len, attention_dim = (lstm_hidden_size + category_dim))
        attention_inputs = torch.cat((sentence_encoding, category_embedding), dim=-1)

        # attention_seq_vec shape: (B,  lstm_hidden_size + category_dim)
        attention_seq_vec = self.attention_seq2vec(sequence=attention_inputs, mask=long_mask)

        # sentiment_vec shape: (B, lstm_hidden_size + category_dim + lstm_hidden_size)
        sentiment_vec = torch.cat((attention_seq_vec, hn_last), dim=-1)

        logits = self.fc(sentiment_vec)

        model_outputs = ACSAModelOutputs(logits=logits)

        return model_outputs
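
A minimal sketch of the pack/pad round trip used above for variable-length batches; the embeddings and lengths are toy values, and `enforce_sorted=False` lets the batch stay in its original order:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch_size, max_len, emb_dim, hidden = 3, 5, 8, 16
embeddings = torch.randn(batch_size, max_len, emb_dim)
lengths = torch.tensor([5, 3, 2])

lstm = nn.LSTM(emb_dim, hidden, batch_first=True)

packed = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
encoding, _ = pad_packed_sequence(packed_out, batch_first=True)   # (batch_size, max_len, hidden)
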
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.
        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        encoder_final_state = get_final_encoder_states(encoded_text, mask)

        output_dict = {
            "encoder_final_state": encoder_final_state,
            "encoded_text": encoded_text,
        }

        return output_dict
    def decode(self, tokens: torch.LongTensor):
        assert tokens.dim() == 1
        tokens = list(tokens.cpu().numpy())
        sentences = self.tokenizer.decode(tokens)
        return sentences