Example 1
0
    def test_iob1_tags_to_spans_extracts_correct_spans(self):
        """Check IOB1 span extraction on labelled tags: normal spans,
        rejection of U- tags, and recovery from invalid sequences.

        Modernized: dropped redundant Python-2 ``u""`` prefixes and replaced
        ``set([...])`` with set literals; string values are unchanged.
        """
        # A leading I- tag (no preceding B-) still opens a span under IOB1.
        tag_sequence = [
            "I-ARG2", "B-ARG1", "I-ARG1", "O", "B-ARG2", "I-ARG2",
            "B-ARG1", "B-ARG2"
        ]
        spans = span_utils.iob1_tags_to_spans(tag_sequence)
        assert set(spans) == {
            ("ARG2", (0, 0)), ("ARG1", (1, 2)), ("ARG2", (4, 5)),
            ("ARG1", (6, 6)), ("ARG2", (7, 7)),
        }

        # Check that it raises when we use U- tags for single tokens.
        tag_sequence = [
            "O", "B-ARG1", "I-ARG1", "O", "B-ARG2", "I-ARG2", "U-ARG1",
            "U-ARG2"
        ]
        with self.assertRaises(span_utils.InvalidTagSequence):
            spans = span_utils.iob1_tags_to_spans(tag_sequence)

        # Check that invalid IOB1 sequences are also handled as spans
        # (an I- tag after O, or with a different label, starts a new span).
        tag_sequence = [
            "O", "B-ARG1", "I-ARG1", "O", "I-ARG1", "B-ARG2", "I-ARG2",
            "B-ARG1", "I-ARG2", "I-ARG2"
        ]
        spans = span_utils.iob1_tags_to_spans(tag_sequence)
        assert set(spans) == {
            ("ARG1", (1, 2)), ("ARG1", (4, 4)), ("ARG2", (5, 6)),
            ("ARG1", (7, 7)), ("ARG2", (8, 9)),
        }
Example 2
0
    def test_iob1_tags_to_spans_extracts_correct_spans_without_labels(self):
        """Check IOB1 span extraction when tags carry no class label:
        every extracted span is keyed by the empty string."""
        # An initial "I" opens a span; each subsequent "B" starts a new one.
        tags = ["I", "B", "I", "O", "B", "I", "B", "B"]
        expected = {("", (0, 0)), ("", (1, 2)), ("", (4, 5)), ("", (6, 6)), ("", (7, 7))}
        assert set(span_utils.iob1_tags_to_spans(tags)) == expected

        # Check that it raises when we use U- tags for single tokens.
        tags = ["O", "B", "I", "O", "B", "I", "U", "U"]
        with self.assertRaises(span_utils.InvalidTagSequence):
            span_utils.iob1_tags_to_spans(tags)

        # Check that invalid IOB1 sequences are also handled as spans.
        tags = ["O", "B", "I", "O", "I", "B", "I", "B", "I", "I"]
        expected = {("", (1, 2)), ("", (4, 4)), ("", (5, 6)), ("", (7, 9))}
        assert set(span_utils.iob1_tags_to_spans(tags)) == expected
Example 3
0
    def test_iob1_tags_to_spans_extracts_correct_spans_without_labels(self):
        """IOB1 extraction for bare B/I/O tags (spans keyed by "").

        NOTE(review): this is a byte-for-byte duplicate of an earlier method
        with the same name in this file; inside one class the later
        definition shadows the earlier — consider deleting one copy.
        """
        for tags, expected in [
            # A leading "I" still opens a span; each "B" starts a fresh one.
            (["I", "B", "I", "O", "B", "I", "B", "B"],
             {("", (0, 0)), ("", (1, 2)), ("", (4, 5)), ("", (6, 6)), ("", (7, 7))}),
            # Invalid IOB1 sequences are also handled as spans.
            (["O", "B", "I", "O", "I", "B", "I", "B", "I", "I"],
             {("", (1, 2)), ("", (4, 4)), ("", (5, 6)), ("", (7, 9))}),
        ]:
            assert set(span_utils.iob1_tags_to_spans(tags)) == expected

        # Check that it raises when we use U- tags for single tokens.
        with self.assertRaises(span_utils.InvalidTagSequence):
            span_utils.iob1_tags_to_spans(["O", "B", "I", "O", "B", "I", "U", "U"])
Example 4
0
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None,
                 prediction_map: Optional[torch.Tensor] = None) -> None:
        """
        Accumulate span-level true/false positive and false negative counts
        for a batch, by decoding per-token tag predictions into spans and
        comparing them against gold spans.

        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, sequence_length, num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class label of shape (batch_size, sequence_length). It must be the same
            shape as the ``predictions`` tensor without the ``num_classes`` dimension.
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        prediction_map: ``torch.Tensor``, optional (default = None).
            A tensor of size (batch_size, num_classes) which provides a mapping from the index of predictions
            to the indices of the label vocabulary. If provided, the output label at each timestep will be
            ``vocabulary.get_index_to_token_vocabulary(prediction_map[batch, argmax(predictions[batch, t]))``,
            rather than simply ``vocabulary.get_index_to_token_vocabulary(argmax(predictions[batch, t]))``.
            This is useful in cases where each Instance in the dataset is associated with a different possible
            subset of labels from a large label-space (IE FrameNet, where each frame has a different set of
            possible roles associated with it).

        Raises
        ------
        ConfigurationError
            If any gold label id is >= the number of prediction classes.
        """
        # No mask supplied: treat every timestep as real (all ones).
        if mask is None:
            mask = torch.ones_like(gold_labels)

        # NOTE(review): presumably detaches tensors from the graph / moves
        # them off-device before Python-side iteration — confirm against the
        # base metric's unwrap_to_tensors implementation.
        predictions, gold_labels, mask, prediction_map = self.unwrap_to_tensors(
            predictions, gold_labels, mask, prediction_map)

        # Gold ids must be valid indices into the prediction distribution.
        num_classes = predictions.size(-1)
        if (gold_labels >= num_classes).any():
            raise ConfigurationError(
                "A gold label passed to SpanBasedF1Measure contains an "
                "id >= {}, the number of classes.".format(num_classes))

        # Per-row count of unmasked timesteps; used to drop padding below.
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        # max(-1) returns (values, indices); [1] keeps the argmax class ids.
        argmax_predictions = predictions.max(-1)[1]

        if prediction_map is not None:
            # Remap local class indices (both predicted and gold) into the
            # shared label-vocabulary index space, per batch row.
            argmax_predictions = torch.gather(prediction_map, 1,
                                              argmax_predictions)
            gold_labels = torch.gather(prediction_map, 1, gold_labels.long())

        argmax_predictions = argmax_predictions.float()

        # Iterate over timesteps in batch.
        batch_size = gold_labels.size(0)
        for i in range(batch_size):
            sequence_prediction = argmax_predictions[i, :]
            sequence_gold_label = gold_labels[i, :]
            length = sequence_lengths[i]

            if length == 0:
                # It is possible to call this metric with sequences which are
                # completely padded. These contribute nothing, so we skip these rows.
                continue

            # Convert the first `length` label ids to their string tags.
            predicted_string_labels = [
                self._label_vocabulary[label_id]
                for label_id in sequence_prediction[:length].tolist()
            ]
            gold_string_labels = [
                self._label_vocabulary[label_id]
                for label_id in sequence_gold_label[:length].tolist()
            ]

            # Decode tag strings into (label, (start, end)) spans under the
            # configured encoding scheme.
            # NOTE(review): if _label_encoding is none of BIO/IOB1/BIOUL,
            # predicted_spans/gold_spans are never bound and the lines below
            # raise NameError — presumably the encoding is validated in
            # __init__; confirm.
            if self._label_encoding == "BIO":
                predicted_spans = bio_tags_to_spans(predicted_string_labels,
                                                    self._ignore_classes)
                gold_spans = bio_tags_to_spans(gold_string_labels,
                                               self._ignore_classes)
            elif self._label_encoding == "IOB1":
                predicted_spans = iob1_tags_to_spans(predicted_string_labels,
                                                     self._ignore_classes)
                gold_spans = iob1_tags_to_spans(gold_string_labels,
                                                self._ignore_classes)
            elif self._label_encoding == "BIOUL":
                predicted_spans = bioul_tags_to_spans(predicted_string_labels,
                                                      self._ignore_classes)
                gold_spans = bioul_tags_to_spans(gold_string_labels,
                                                 self._ignore_classes)

            predicted_spans = self._handle_continued_spans(predicted_spans)
            gold_spans = self._handle_continued_spans(gold_spans)

            # Match predicted spans against gold: each matched gold span is
            # consumed (removed) so it can only be credited once.
            for span in predicted_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            # These spans weren't predicted.
            for span in gold_spans:
                self._false_negatives[span[0]] += 1