Example #1
    def test_viterbi_decode(self):
        # Test Viterbi decoding is equal to greedy decoding with no pairwise potentials.
        sequence_predictions = torch.nn.functional.softmax(
            Variable(torch.rand([5, 9])), dim=-1)
        transition_matrix = torch.zeros([9, 9])
        indices, _ = viterbi_decode(sequence_predictions.data,
                                    transition_matrix)
        _, argmax_indices = torch.max(sequence_predictions, 1)
        assert indices == argmax_indices.data.squeeze().tolist()

        # Test that pairwise potentials affect the sequence correctly and that
        # viterbi_decode can handle -inf values.
        sequence_predictions = torch.FloatTensor([[0, 0, 0, 3, 4],
                                                  [0, 0, 0, 3, 4],
                                                  [0, 0, 0, 3, 4],
                                                  [0, 0, 0, 3, 4],
                                                  [0, 0, 0, 3, 4],
                                                  [0, 0, 0, 3, 4]])
        # The same tags shouldn't appear sequentially.
        transition_matrix = torch.zeros([5, 5])
        for i in range(5):
            transition_matrix[i, i] = float("-inf")
        indices, _ = viterbi_decode(sequence_predictions, transition_matrix)
        assert indices == [4, 3, 4, 3, 4, 3]

        # Test that unbalanced pairwise potentials break ties
        # between paths with equal unary potentials.
        sequence_predictions = torch.FloatTensor([[0, 0, 0, 4, 4],
                                                  [0, 0, 0, 4, 4],
                                                  [0, 0, 0, 4, 4],
                                                  [0, 0, 0, 4, 4],
                                                  [0, 0, 0, 4, 4],
                                                  [0, 0, 0, 4, 4]])
        # The 5th tag has a penalty for appearing sequentially
        # or for transitioning to the 4th tag, so the unique best
        # path takes the 4th tag at every step.
        transition_matrix = torch.zeros([5, 5])
        transition_matrix[4, 4] = -10
        transition_matrix[4, 3] = -10
        indices, _ = viterbi_decode(sequence_predictions, transition_matrix)
        assert indices == [3, 3, 3, 3, 3, 3]

        sequence_predictions = torch.FloatTensor([[1, 0, 0, 4], [1, 0, 6, 2],
                                                  [0, 3, 0, 4]])
        # Best path would normally be [3, 2, 3] but we add a
        # potential from 2 -> 1, making [3, 2, 1] the best path.
        transition_matrix = torch.zeros([4, 4])
        transition_matrix[0, 0] = 1
        transition_matrix[2, 1] = 5
        indices, value = viterbi_decode(sequence_predictions,
                                        transition_matrix)
        assert indices == [3, 2, 1]
        assert value.numpy() == 18
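
For orientation, the call pattern these tests exercise can also be run standalone. The snippet below is a minimal sketch; the import path `allennlp.nn.util` is an assumption and is not shown in the test above.

import torch
from allennlp.nn.util import viterbi_decode  # assumed import path

# Per-step tag scores for a 3-step sequence over 4 tags.
emissions = torch.FloatTensor([[1, 0, 0, 4],
                               [1, 0, 6, 2],
                               [0, 3, 0, 4]])
# All-zero pairwise potentials: decoding reduces to per-step argmax.
transitions = torch.zeros([4, 4])
best_path, best_score = viterbi_decode(emissions, transitions)
print(best_path)   # [3, 2, 3]
print(best_score)  # a tensor holding 14.0 (4 + 6 + 4)
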
Example #2
 def decode(
         self, output_dict: Dict[str,
                                 torch.Tensor]) -> Dict[str, torch.Tensor]:
     """
     Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
     constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
     ``"tags"`` key to the dictionary with the result.
     """
     all_predictions = output_dict['class_probabilities']
     if isinstance(all_predictions, numpy.ndarray):
         all_predictions = torch.from_numpy(all_predictions)
     if all_predictions.dim() == 3:
         predictions_list = [
             all_predictions[i] for i in range(all_predictions.shape[0])
         ]
     else:
         predictions_list = [all_predictions]
     all_tags = []
     transition_matrix = self.get_viterbi_pairwise_potentials()
     for predictions in predictions_list:
         max_likelihood_sequence, _ = viterbi_decode(
             predictions, transition_matrix)
         tags = [
             self.vocab.get_token_from_index(x, namespace="labels")
             for x in max_likelihood_sequence
         ]
         all_tags.append(tags)
     if len(all_tags) == 1:
         all_tags = all_tags[0]  # type: ignore
     output_dict['tags'] = all_tags
     return output_dict
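
The BIO constraint mentioned in the docstring comes from `get_viterbi_pairwise_potentials`, whose body is not shown here. The following is a minimal sketch of how such a pairwise potential matrix could be built; the helper name and label set are illustrative assumptions, not the model's actual implementation.

import torch

def bio_pairwise_potentials(labels):
    # Forbid any I-X tag that does not follow B-X or I-X of the same X
    # by assigning that transition a score of -inf.
    num_labels = len(labels)
    matrix = torch.zeros([num_labels, num_labels])
    for i, previous_label in enumerate(labels):
        for j, label in enumerate(labels):
            if label.startswith("I-") and previous_label not in (label, "B-" + label[2:]):
                matrix[i, j] = float("-inf")
    return matrix

# Illustrative label set (assumption).
labels = ["O", "B-ARG0", "I-ARG0", "B-V", "I-V"]
transition_matrix = bio_pairwise_potentials(labels)
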
Example #3
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        ph: Do NOT perform constrained Viterbi decoding - we are interested in learning
        dynamics, not best performance. (The transition matrix below is all zeros, so
        decoding reduces to greedy argmax.)
        """
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [all_predictions[i].detach().cpu() for i in range(all_predictions.size(0))]
        else:
            predictions_list = [all_predictions]
        all_tags = []

        # ph: the transition matrix and start transitions contain only zeros (and no -inf, which would signal an illegal transition)
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])
        start_transitions = torch.zeros(num_labels)

        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix,
                                                        allowed_start_transitions=start_transitions)
            tags = [self.vocab.get_token_from_index(x, namespace="labels")
                    for x in max_likelihood_sequence]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict
Example #4
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        ``"tags"`` key to the dictionary with the result.
        """
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length], transition_matrix)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict
Example #5
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces), as otherwise we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens when a word
        is split into multiple wordpieces and we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # **************** Different ********************
        # We add in the offsets here so we can compute the un-wordpieced tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(
                    x, namespace=self._label_namespace)
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            if isinstance(self.bert_model,
                          PretrainedTransformerMismatchedEmbedder):
                word_tags.append(tags)
            else:
                word_tags.append([tags[i] for i in offsets])
            # print(word_tags)
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict
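
To make the offset handling described in the docstring concrete, here is a tiny illustration of how start offsets recover word-level tags from wordpiece-level tags; the tag values and offsets are made up for illustration.

# Wordpiece-level tags for a sentence whose second word splits into three pieces.
wordpiece_tags = ["O", "B-ARG0", "I-ARG0", "I-ARG0", "B-V"]
# One offset per word, pointing at the first wordpiece of that word.
start_offsets = [0, 1, 4]
word_tags = [wordpiece_tags[i] for i in start_offsets]
# word_tags == ["O", "B-ARG0", "B-V"] -- still a valid BIO sequence.
# Selecting the *last* wordpiece of each word instead could yield an I- tag
# with no preceding B- tag, which is exactly the failure mode described above.
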
Example #6
    def viterbi_tags(self, logits: torch.Tensor,
                     mask: torch.Tensor) -> List[Tuple[List[int], float]]:
        """
        Uses viterbi algorithm to find most likely tags for the given inputs.
        If constraints are applied, disallows all other transitions.
        """
        _, max_seq_length, num_tags = logits.size()

        # Get the tensors out of the variables
        logits, mask = logits.data, mask.data

        # Augment transitions matrix with start and end transitions
        start_tag = num_tags
        end_tag = num_tags + 1
        transitions = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.0)

        # Apply transition constraints
        constrained_transitions = self.transitions * self._constraint_mask[:num_tags, :num_tags] + -10000.0 * (
            1 - self._constraint_mask[:num_tags, :num_tags])
        transitions[:num_tags, :num_tags] = constrained_transitions.data

        if self.include_start_end_transitions:
            transitions[start_tag, :num_tags] = (
                self.start_transitions.detach()
                * self._constraint_mask[start_tag, :num_tags].data
                + -10000.0 * (1 - self._constraint_mask[start_tag, :num_tags].detach()))
            transitions[:num_tags, end_tag] = (
                self.end_transitions.detach()
                * self._constraint_mask[:num_tags, end_tag].data
                + -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()))
        else:
            transitions[start_tag, :num_tags] = -10000.0 * (
                1 - self._constraint_mask[start_tag, :num_tags].detach())
            transitions[:num_tags, end_tag] = -10000.0 * (
                1 - self._constraint_mask[:num_tags, end_tag].detach())

        best_paths = []
        # Pad the max sequence length by 2 to account for start_tag + end_tag.
        tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)

        for prediction, prediction_mask in zip(logits, mask):
            sequence_length = torch.sum(prediction_mask)

            # Start with everything totally unlikely
            tag_sequence.fill_(-10000.0)
            # At timestep 0 we must have the START_TAG
            tag_sequence[0, start_tag] = 0.0
            # At steps 1, ..., sequence_length we just use the incoming prediction
            tag_sequence[1:(sequence_length +
                            1), :num_tags] = prediction[:sequence_length]
            # And at the last timestep we must have the END_TAG
            tag_sequence[sequence_length + 1, end_tag] = 0.0

            # We pass the tags and the transitions to ``viterbi_decode``.
            viterbi_path, viterbi_score = util.viterbi_decode(
                tag_sequence[:(sequence_length + 2)], transitions)
            # Get rid of START and END sentinels and append.
            viterbi_path = viterbi_path[1:-1]
            best_paths.append((viterbi_path, viterbi_score.item()))

        return best_paths
    def _get_gold_answer(self,
                         gold_answer_representations: Dict[str,
                                                           torch.LongTensor],
                         log_probs: torch.LongTensor,
                         mask: torch.LongTensor) -> torch.LongTensor:
        answer_as_text_to_disjoint_bios = gold_answer_representations[
            'answer_as_text_to_disjoint_bios']
        answer_as_list_of_bios = gold_answer_representations[
            'answer_as_list_of_bios']
        span_bio_labels = gold_answer_representations['span_bio_labels']

        with torch.no_grad():
            answer_as_list_of_bios = answer_as_list_of_bios * mask.unsqueeze(1)
            if answer_as_text_to_disjoint_bios.sum() > 0:
                # TODO: verify correctness (Elad)

                full_bio = span_bio_labels

                if self._generation_top_k > 0:
                    most_likely_predictions, _ = viterbi_decode(
                        log_probs.cpu(),
                        self._bio_allowed_transitions,
                        top_k=self._generation_top_k)
                    most_likely_predictions = torch.FloatTensor(
                        most_likely_predictions).to(log_probs.device)
                    # ^ Should be converted to tensor

                    most_likely_predictions = most_likely_predictions * mask.unsqueeze(
                        1)

                    generated_list_of_bios = self._filter_correct_predictions(
                        most_likely_predictions,
                        answer_as_text_to_disjoint_bios, full_bio)

                    is_pregenerated_answer_format_mask = (
                        answer_as_list_of_bios.sum(
                            (1, 2)) > 0).unsqueeze(-1).unsqueeze(-1).long()
                    bio_seqs = torch.cat(
                        (answer_as_list_of_bios,
                         (generated_list_of_bios *
                          (1 - is_pregenerated_answer_format_mask))),
                        dim=1)

                    bio_seqs = self._add_full_bio(bio_seqs, full_bio)
                else:
                    is_pregenerated_answer_format_mask = (
                        answer_as_list_of_bios.sum((1, 2)) > 0).long()
                    bio_seqs = torch.cat(
                        (answer_as_list_of_bios,
                         (full_bio * (1 - is_pregenerated_answer_format_mask
                                      ).unsqueeze(-1)).unsqueeze(1)),
                        dim=1)
            else:
                bio_seqs = answer_as_list_of_bios

        return bio_seqs
    def tag(self, text_field: TextField,
            verb_indicator: IndexField) -> Dict[str, Any]:
        """
        Perform inference on an ``Instance`` consisting of a single ``TextField`` representing
        the sentence and an ``IndexField`` representing an optional index into the sentence
        denoting a verbal predicate.

        Returned sequence is the maximum likelihood tag sequence under the constraint that
        the sequence must be a valid BIO sequence.

        Parameters
        ----------
        text_field : ``TextField``, required.
            A ``TextField`` containing the text to be tagged.
        verb_indicator: ``IndexField``, required.
            The index of the verb whose arguments we are labeling.

        Returns
        -------
        A Dict containing:

        tags : List[str]
            A list the length of the text input, containing the predicted (argmax) tag
            from the model per token.
        class_probabilities : numpy.Array
            An array of shape (text_input_length, num_classes), where each row is a
            distribution over classes for a given token in the sentence.
        """
        instance = Instance({
            "tokens": text_field,
            "verb_indicator": verb_indicator
        })
        instance.index_fields(self.vocab)
        model_input = arrays_to_variables(instance.as_array_dict(),
                                          add_batch_dimension=True,
                                          for_training=False)
        output_dict = self.forward(**model_input)

        # Remove batch dimension, as we only had one input.
        predictions = output_dict["class_probabilities"].data.squeeze(0)
        transition_matrix = self.get_viterbi_pairwise_potentials()

        max_likelihood_sequence, _ = viterbi_decode(predictions,
                                                    transition_matrix)
        tags = [
            self.vocab.get_token_from_index(x, namespace="tags")
            for x in max_likelihood_sequence
        ]

        return {"tags": tags, "class_probabilities": predictions.numpy()}
Example #9
    def decode_tags(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [all_predictions[i].data.cpu() for i in range(all_predictions.size(0))]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix)
            tags = [self.vocab.get_token_from_index(x, namespace="labels")
                    for x in max_likelihood_sequence]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict
    def viterbi_tags(self, logits: Variable, mask: Variable) -> List[List[int]]:
        """
        Uses viterbi algorithm to find most likely tags for the given inputs.
        If constraints are applied, disallows all other transitions.
        """
        _, max_seq_length, num_tags = logits.size()

        # Get the tensors out of the variables
        logits, mask = logits.data, mask.data

        # Augment transitions matrix with start and end transitions
        start_tag = num_tags
        end_tag = num_tags + 1
        transitions = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.)

        # Apply transition constraints
        constrained_transitions = (self.transitions * self._constraint_mask +
                                   -10000.0 * (1 - self._constraint_mask))

        transitions[:num_tags, :num_tags] = constrained_transitions.data
        transitions[start_tag, :num_tags] = self.start_transitions.data
        transitions[:num_tags, end_tag] = self.end_transitions.data

        all_tags = []
        # Pad the max sequence length by 2 to account for start_tag + end_tag.
        tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)

        for prediction, prediction_mask in zip(logits, mask):
            sequence_length = torch.sum(prediction_mask)

            # Start with everything totally unlikely
            tag_sequence.fill_(-10000.)
            # At timestep 0 we must have the START_TAG
            tag_sequence[0, start_tag] = 0.
            # At steps 1, ..., sequence_length we just use the incoming prediction
            tag_sequence[1:(sequence_length + 1), :num_tags] = prediction[:sequence_length]
            # And at the last timestep we must have the END_TAG
            tag_sequence[sequence_length + 1, end_tag] = 0.

            # We pass the tags and the transitions to ``viterbi_decode``.
            viterbi_path, _ = util.viterbi_decode(tag_sequence[:(sequence_length + 2)], transitions)
            # Get rid of START and END sentinels and append.
            all_tags.append(viterbi_path[1:-1])

        return all_tags
    def viterbi_tags(self, logits: Variable,
                     mask: Variable) -> List[List[int]]:
        """
        Uses viterbi algorithm to find most likely tags for the given inputs.
        """
        _, max_seq_length, num_tags = logits.size()

        # Get the tensors out of the variables
        logits, mask = logits.data, mask.data

        # Augment transitions matrix with start and end transitions
        start_tag = num_tags
        end_tag = num_tags + 1
        transitions = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.)

        transitions[:num_tags, :num_tags] = self.transitions.data
        transitions[start_tag, :num_tags] = self.start_transitions.data
        transitions[:num_tags, end_tag] = self.end_transitions.data

        all_tags = []
        # Pad the max sequence length by 2 to account for start_tag + end_tag.
        tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)

        for prediction, prediction_mask in zip(logits, mask):
            sequence_length = torch.sum(prediction_mask)

            # Start with everything totally unlikely
            tag_sequence.fill_(-10000.)
            # At timestep 0 we must have the START_TAG
            tag_sequence[0, start_tag] = 0.
            # At steps 1, ..., sequence_length we just use the incoming prediction
            tag_sequence[1:(sequence_length +
                            1), :num_tags] = prediction[:sequence_length]
            # And at the last timestep we must have the END_TAG
            tag_sequence[sequence_length + 1, end_tag] = 0.

            # We pass the tags and the transitions to ``viterbi_decode``.
            viterbi_path, _ = util.viterbi_decode(
                tag_sequence[:(sequence_length + 2)], transitions)
            # Get rid of START and END sentinels and append.
            all_tags.append(viterbi_path[1:-1])

        return all_tags
Example #12
def decode(transition_matrix, vocab,
           output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
    constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
    ``"tags"`` key to the dictionary with the result.
    """
    all_predictions = output_dict['class_probabilities']
    sequence_lengths = get_lengths_from_binary_sequence_mask(
        output_dict["mask"]).data.tolist()
    max_seq_length = max(sequence_lengths)
    if all_predictions.dim() == 3:
        predictions_list = [
            all_predictions[i].detach().cpu()
            for i in range(all_predictions.size(0))
        ]
    else:
        predictions_list = [all_predictions]
    all_tags = []
    all_prob_seq = []
    num_tags = transition_matrix.shape[0] - 2
    tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)
    for predictions, length in zip(predictions_list, sequence_lengths):
        tag_sequence.fill_(-10000.)
        tag_sequence[0, num_tags] = 0.
        tag_sequence[1:(length + 1), :num_tags] = predictions[:length]
        tag_sequence[(length + 1), num_tags + 1] = 0.

        max_likelihood_sequence, _ = viterbi_decode(
            tag_sequence[:(length + 2)], transition_matrix)
        max_likelihood_sequence = max_likelihood_sequence[1:-1]
        tags = [
            vocab.get_token_from_index(x, namespace="labels")
            for x in max_likelihood_sequence
        ]
        all_tags.append(tags)
        all_prob_seq.append(max_likelihood_sequence)
    output_dict['tags'] = all_tags
    output_dict['max_seq'] = all_prob_seq
    return output_dict
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        ``"tags"`` key to the dictionary with the result.
        """
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [all_predictions[i].detach().cpu() for i in range(all_predictions.size(0))]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix)
            tags = [self.vocab.get_token_from_index(x, namespace="labels")
                    for x in max_likelihood_sequence]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict
    def viterbi_tags(
        self,
        logits: torch.Tensor,
        mask: torch.BoolTensor = None,
        top_k: int = None
    ) -> Union[List[VITERBI_DECODING], List[List[VITERBI_DECODING]]]:
        """
        Uses viterbi algorithm to find most likely tags for the given inputs.
        If constraints are applied, disallows all other transitions.

        Returns a list of results, of the same size as the batch (one result per batch member)
        Each result is a List of length top_k, containing the top K viterbi decodings
        Each decoding is a tuple  (tag_sequence, viterbi_score)

        For backwards compatibility, if top_k is None, then instead returns a flat list of
        tag sequences (the top tag sequence for each batch item).
        """
        if mask is None:
            mask = torch.ones(*logits.shape[:2],
                              dtype=torch.bool,
                              device=logits.device)

        if top_k is None:
            top_k = 1
            flatten_output = True
        else:
            flatten_output = False

        _, max_seq_length, num_tags = logits.size()

        # Get the tensors out of the variables
        logits, mask = logits.data, mask.data

        # Augment transitions matrix with start and end transitions
        start_tag = num_tags
        end_tag = num_tags + 1
        transitions = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.0)

        # Apply transition constraints
        constrained_transitions = self.transitions * self._constraint_mask[:num_tags, :num_tags] + -10000.0 * (
            1 - self._constraint_mask[:num_tags, :num_tags])
        transitions[:num_tags, :num_tags] = constrained_transitions.data

        if self.include_start_end_transitions:
            transitions[start_tag, :num_tags] = (
                self.start_transitions.detach()
                * self._constraint_mask[start_tag, :num_tags].data
                + -10000.0 * (1 - self._constraint_mask[start_tag, :num_tags].detach()))
            transitions[:num_tags, end_tag] = (
                self.end_transitions.detach()
                * self._constraint_mask[:num_tags, end_tag].data
                + -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()))
        else:
            transitions[start_tag, :num_tags] = -10000.0 * (
                1 - self._constraint_mask[start_tag, :num_tags].detach())
            transitions[:num_tags, end_tag] = -10000.0 * (
                1 - self._constraint_mask[:num_tags, end_tag].detach())

        best_paths = []
        # Pad the max sequence length by 2 to account for start_tag + end_tag.
        tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)

        for prediction, prediction_mask in zip(logits, mask):
            mask_indices = prediction_mask.nonzero().squeeze()
            masked_prediction = torch.index_select(prediction, 0, mask_indices)
            sequence_length = masked_prediction.shape[0]

            # Start with everything totally unlikely
            tag_sequence.fill_(-10000.0)
            # At timestep 0 we must have the START_TAG
            tag_sequence[0, start_tag] = 0.0
            # At steps 1, ..., sequence_length we just use the incoming prediction
            tag_sequence[1:(sequence_length +
                            1), :num_tags] = masked_prediction
            # And at the last timestep we must have the END_TAG
            tag_sequence[sequence_length + 1, end_tag] = 0.0

            # We pass the tags and the transitions to `viterbi_decode`.
            viterbi_paths, viterbi_scores = util.viterbi_decode(
                tag_sequence=tag_sequence[:(sequence_length + 2)],
                transition_matrix=transitions,
                top_k=top_k,
            )
            top_k_paths = []
            for viterbi_path, viterbi_score in zip(viterbi_paths,
                                                   viterbi_scores):
                # Get rid of START and END sentinels and append.
                viterbi_path = viterbi_path[1:-1]
                top_k_paths.append((viterbi_path, viterbi_score.item()))
            best_paths.append(top_k_paths)

        if flatten_output:
            return [top_k_paths[0] for top_k_paths in best_paths]

        return best_paths
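
A hypothetical call pattern for this method, assuming it lives on an AllenNLP `ConditionalRandomField` module; the import path, tensor shapes, and contents below are assumptions for illustration only.

import torch
from allennlp.modules import ConditionalRandomField  # assumed import path

num_tags = 5
crf = ConditionalRandomField(num_tags)

logits = torch.randn(2, 7, num_tags)       # (batch_size, seq_len, num_tags)
mask = torch.ones(2, 7, dtype=torch.bool)  # (batch_size, seq_len)

top_paths = crf.viterbi_tags(logits, mask, top_k=3)
# top_paths[b] is a list of up to 3 (tag_sequence, viterbi_score) tuples.

flat_paths = crf.viterbi_tags(logits, mask)
# With top_k left as None the output is flattened:
# flat_paths[b] is a single (tag_sequence, viterbi_score) tuple.
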
Example #15
    def test_viterbi_decode(self):
        # Test Viterbi decoding is equal to greedy decoding with no pairwise potentials.
        sequence_logits = torch.nn.functional.softmax(torch.rand([5, 9]), dim=-1)
        transition_matrix = torch.zeros([9, 9])
        indices, _ = util.viterbi_decode(sequence_logits.data, transition_matrix)
        _, argmax_indices = torch.max(sequence_logits, 1)
        assert indices == argmax_indices.data.squeeze().tolist()

        # Test that pairwise potentials affect the sequence correctly and that
        # viterbi_decode can handle -inf values.
        sequence_logits = torch.FloatTensor([[0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4]])
        # The same tags shouldn't appear sequentially.
        transition_matrix = torch.zeros([5, 5])
        for i in range(5):
            transition_matrix[i, i] = float("-inf")
        indices, _ = util.viterbi_decode(sequence_logits, transition_matrix)
        assert indices == [4, 3, 4, 3, 4, 3]

        # Test that unbalanced pairwise potentials break ties
        # between paths with equal unary potentials.
        sequence_logits = torch.FloatTensor([[0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4]])
        # The 5th tag has a penalty for appearing sequentially
        # or for transitioning to the 4th tag, so the unique best
        # path takes the 4th tag at every step.
        transition_matrix = torch.zeros([5, 5])
        transition_matrix[4, 4] = -10
        transition_matrix[4, 3] = -10
        indices, _ = util.viterbi_decode(sequence_logits, transition_matrix)
        assert indices == [3, 3, 3, 3, 3, 3]

        sequence_logits = torch.FloatTensor([[1, 0, 0, 4],
                                             [1, 0, 6, 2],
                                             [0, 3, 0, 4]])
        # Best path would normally be [3, 2, 3] but we add a
        # potential from 2 -> 1, making [3, 2, 1] the best path.
        transition_matrix = torch.zeros([4, 4])
        transition_matrix[0, 0] = 1
        transition_matrix[2, 1] = 5
        indices, value = util.viterbi_decode(sequence_logits, transition_matrix)
        assert indices == [3, 2, 1]
        assert value.numpy() == 18

        # Test that providing evidence results in paths containing specified tags.
        sequence_logits = torch.FloatTensor([[0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7]])
        # The 5th tag has a penalty for appearing sequentially
        # or for transitioning to the 4th tag, so (absent observations)
        # the best path takes the 4th tag at every step.
        transition_matrix = torch.zeros([5, 5])
        transition_matrix[4, 4] = -10
        transition_matrix[4, 3] = -2
        # The 1st, 4th and 5th sequence elements are observed - they should be
        # equal to 2, 0 and 4 (-1 marks an unobserved position). The last tag should
        # be equal to 3, because although the penalty for transitioning to the 4th tag
        # is -2, the unary potential is 7, which is greater than the combination for
        # any of the other labels.
        observations = [2, -1, -1, 0, 4, -1]
        indices, _ = util.viterbi_decode(sequence_logits,
                                         transition_matrix,
                                         observations)
        assert indices == [2, 3, 3, 0, 4, 3]
Example #16
    def test_viterbi_decode(self):
        # Test Viterbi decoding is equal to greedy decoding with no pairwise potentials.
        sequence_logits = torch.nn.functional.softmax(torch.rand([5, 9]), dim=-1)
        transition_matrix = torch.zeros([9, 9])
        indices, _ = util.viterbi_decode(sequence_logits.data, transition_matrix)
        _, argmax_indices = torch.max(sequence_logits, 1)
        assert indices == argmax_indices.data.squeeze().tolist()

        # Test that pairwise potentials affect the sequence correctly and that
        # viterbi_decode can handle -inf values.
        sequence_logits = torch.FloatTensor([[0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4],
                                             [0, 0, 0, 3, 4]])
        # The same tags shouldn't appear sequentially.
        transition_matrix = torch.zeros([5, 5])
        for i in range(5):
            transition_matrix[i, i] = float("-inf")
        indices, _ = util.viterbi_decode(sequence_logits, transition_matrix)
        assert indices == [4, 3, 4, 3, 4, 3]

        # Test that unbalanced pairwise potentials break ties
        # between paths with equal unary potentials.
        sequence_logits = torch.FloatTensor([[0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4],
                                             [0, 0, 0, 4, 4]])
        # The 5th tag has a penalty for appearing sequentially
        # or for transitioning to the 4th tag, so the unique best
        # path takes the 4th tag at every step.
        transition_matrix = torch.zeros([5, 5])
        transition_matrix[4, 4] = -10
        transition_matrix[4, 3] = -10
        indices, _ = util.viterbi_decode(sequence_logits, transition_matrix)
        assert indices == [3, 3, 3, 3, 3, 3]

        sequence_logits = torch.FloatTensor([[1, 0, 0, 4],
                                             [1, 0, 6, 2],
                                             [0, 3, 0, 4]])
        # Best path would normally be [3, 2, 3] but we add a
        # potential from 2 -> 1, making [3, 2, 1] the best path.
        transition_matrix = torch.zeros([4, 4])
        transition_matrix[0, 0] = 1
        transition_matrix[2, 1] = 5
        indices, value = util.viterbi_decode(sequence_logits, transition_matrix)
        assert indices == [3, 2, 1]
        assert value.numpy() == 18

        # Test that providing evidence results in paths containing specified tags.
        sequence_logits = torch.FloatTensor([[0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7],
                                             [0, 0, 0, 7, 7]])
        # The 5th tag has a penalty for appearing sequentially
        # or for transitioning to the 4th tag, so (absent observations)
        # the best path takes the 4th tag at every step.
        transition_matrix = torch.zeros([5, 5])
        transition_matrix[4, 4] = -10
        transition_matrix[4, 3] = -2
        # The 1st, 4th and 5th sequence elements are observed - they should be
        # equal to 2, 0 and 4 (-1 marks an unobserved position). The last tag should
        # be equal to 3, because although the penalty for transitioning to the 4th tag
        # is -2, the unary potential is 7, which is greater than the combination for
        # any of the other labels.
        observations = [2, -1, -1, 0, 4, -1]
        indices, _ = util.viterbi_decode(sequence_logits,
                                         transition_matrix,
                                         observations)
        assert indices == [2, 3, 3, 0, 4, 3]
    def viterbi_tags(
            self,
            logits: torch.Tensor,
            mask: torch.Tensor,
            constraint_mask: torch.Tensor = None
    ) -> List[Tuple[List[int], float]]:
        """
        Uses viterbi algorithm to find most likely tags for the given inputs.
        If constraints are applied, disallows all other transitions.

        Parameters
        ----------
        logits: torch.Tensor
            Shape: (batch_size, max_seq_length, num_tags) Tensor of logits.
        mask: torch.Tensor
            Shape: (batch_size, max_seq_length) Tensor of the sequence mask.
        constraint_mask: torch.Tensor, optional (default=None)
            Shape: (batch_size, num_tags+2, num_tags+2) Tensor of the allowed
            transitions for each example in the batch.
        """
        # pylint: disable=arguments-differ
        if constraint_mask is None:
            # Defer to superclass function if there is no custom constraint mask.
            return super().viterbi_tags(logits=logits, mask=mask)
        # We have a custom constraint mask for each example, so we need to re-mask
        # when we make each prediction.
        batch_size, max_seq_length, num_tags = logits.size()

        assert list(constraint_mask.size()) == [
            batch_size, num_tags + 2, num_tags + 2
        ]

        # Get the tensors out of the variables
        logits, mask = logits.data, mask.data

        start_tag = num_tags
        end_tag = num_tags + 1
        best_paths = []
        # Pad the max sequence length by 2 to account for start_tag + end_tag.
        tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)

        for prediction, prediction_mask, prediction_constraint_mask in zip(
                logits, mask, constraint_mask):
            prediction_constraint_mask = torch.nn.Parameter(
                prediction_constraint_mask, requires_grad=False)
            # Augment transitions matrix with start and end transitions
            transitions = torch.Tensor(num_tags + 2,
                                       num_tags + 2).fill_(-10000.)
            # Apply transition constraints
            constrained_transitions = (
                self.transitions *
                prediction_constraint_mask[:num_tags, :num_tags] + -10000.0 *
                (1 - prediction_constraint_mask[:num_tags, :num_tags]))
            transitions[:num_tags, :num_tags] = constrained_transitions.data

            if self.include_start_end_transitions:
                transitions[start_tag, :num_tags] = (
                    self.start_transitions.detach() *
                    prediction_constraint_mask[start_tag, :num_tags].data +
                    -10000.0 *
                    (1 -
                     prediction_constraint_mask[start_tag, :num_tags].detach())
                )
                transitions[:num_tags, end_tag] = (
                    self.end_transitions.detach() *
                    prediction_constraint_mask[:num_tags, end_tag].data +
                    -10000.0 *
                    (1 -
                     prediction_constraint_mask[:num_tags, end_tag].detach()))
            else:
                transitions[start_tag, :num_tags] = (-10000.0 * (
                    1 -
                    prediction_constraint_mask[start_tag, :num_tags].detach()))
                transitions[:num_tags, end_tag] = (
                    -10000.0 *
                    (1 -
                     prediction_constraint_mask[:num_tags, end_tag].detach()))

            sequence_length = torch.sum(prediction_mask)

            # Start with everything totally unlikely
            tag_sequence.fill_(-10000.)
            # At timestep 0 we must have the START_TAG
            tag_sequence[0, start_tag] = 0.
            # At steps 1, ..., sequence_length we just use the incoming prediction
            tag_sequence[1:(sequence_length +
                            1), :num_tags] = prediction[:sequence_length]
            # And at the last timestep we must have the END_TAG
            tag_sequence[sequence_length + 1, end_tag] = 0.

            # We pass the tags and the transitions to ``viterbi_decode``.
            viterbi_path, viterbi_score = util.viterbi_decode(
                tag_sequence[:(sequence_length + 2)], transitions)
            # Get rid of START and END sentinels and append.
            viterbi_path = viterbi_path[1:-1]
            best_paths.append((viterbi_path, viterbi_score.item()))
        return best_paths