def test_span_metrics_are_computed_correctly(self):
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format, )

        batch_verb_indices = [2]
        batch_sentences = [["The", "cat", "loves", "hats", "."]]
        batch_bio_predicted_tags = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer(ignore_classes=["V"])
        srl_scorer(batch_verb_indices, batch_sentences,
                   batch_conll_predicted_tags, batch_conll_gold_tags)
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 9
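        # 9 metrics = 3 (precision/recall/F1) x 3 groups (ARG0, ARG1, overall);
        # "V" is excluded via ignore_classes. Working through the spans above:
        # predicted {ARG0: [0, 0], ARG1: [1, 1], ARG1: [3, 3]} vs gold
        # {ARG0: [0, 1], ARG1: [3, 3]}; only ARG1 [3, 3] matches, hence
        # ARG1 precision 1/2, ARG1 recall 1/1, and overall 1/3 and 1/2.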
        assert_allclose(metrics["precision-ARG0"], 0.0)
        assert_allclose(metrics["recall-ARG0"], 0.0)
        assert_allclose(metrics["f1-measure-ARG0"], 0.0)
        assert_allclose(metrics["precision-ARG1"], 0.5)
        assert_allclose(metrics["recall-ARG1"], 1.0)
        assert_allclose(metrics["f1-measure-ARG1"], 2 / 3)
        assert_allclose(metrics["precision-overall"], 1 / 3)
        assert_allclose(metrics["recall-overall"], 1 / 2)
        assert_allclose(metrics["f1-measure-overall"],
                        (2 * (1 / 3) * (1 / 2)) / ((1 / 3) + (1 / 2)))

    def test_bio_tags_correctly_convert_to_conll_format(self):
        bio_tags = ["B-ARG-1", "I-ARG-1", "O", "B-V", "B-ARGM-ADJ", "O"]
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format, )

        conll_tags = convert_bio_tags_to_conll_format(bio_tags)
        assert conll_tags == ["(ARG-1*", "*)", "*", "(V*)", "(ARGM-ADJ*)", "*"]
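
    # For reference, a hypothetical re-implementation of the conversion the
    # test above exercises; an illustrative sketch only, not the AllenNLP
    # implementation:
    @staticmethod
    def _bio_to_conll_sketch(bio_tags):
        conll_tags = []
        for i, tag in enumerate(bio_tags):
            if tag == "O":
                conll_tags.append("*")
                continue
            label = tag[2:]
            starts_span = tag.startswith("B-")
            ends_span = (i + 1 == len(bio_tags)
                         or bio_tags[i + 1] != "I-" + label)
            prefix = "(" + label if starts_span else ""
            conll_tags.append(prefix + ("*)" if ends_span else "*"))
        return conll_tags
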
    def test_distributed_setting_throws_an_error(self):
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format, )

        batch_verb_indices = [2]
        batch_sentences = [["The", "cat", "loves", "hats", "."]]
        batch_bio_predicted_tags = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        metric_kwargs = {
            "batch_verb_indices": [batch_verb_indices, batch_verb_indices],
            "batch_sentences": [batch_sentences, batch_sentences],
            "batch_conll_formatted_predicted_tags": [
                batch_conll_predicted_tags,
                batch_conll_predicted_tags,
            ],
            "batch_conll_formatted_gold_tags":
            [batch_conll_gold_tags, batch_conll_gold_tags],
        }

        desired_values = {}  # it does not matter, we expect the run to fail.

        with pytest.raises(Exception) as exc:
            run_distributed_test(
                [-1, -1],
                global_distributed_metric,
                SrlEvalScorer(ignore_classes=["V"]),
                metric_kwargs,
                desired_values,
                exact=True,
            )
        # The assertion must sit outside the `with` block; inside it, the
        # raised exception would skip it entirely.
        assert (
            "RuntimeError: Distributed aggregation for `SrlEvalScorer` is currently not supported."
            in str(exc.value))
Example #4
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        frame_indicator: torch.Tensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
        frame_tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        frame_indicator : `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the frame
            in the sentence. This should have shape (batch_size, num_tokens). Similar to
            `verb_indicator`, but accounts for the BERT wordpiece tokenizer by marking
            only the first sub-token of the frame.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        frame_tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the gold frames
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            Metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under the 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        input_ids = util.get_token_ids_from_text_field_tensors(tokens)
        bert_embeddings, _ = self.transformer(
            input_ids=input_ids,
            token_type_ids=verb_indicator,
            attention_mask=mask,
            return_dict=False,
        )
        # extract embeddings
        embedded_text_input = self.embedding_dropout(bert_embeddings)
        frame_embeddings = embedded_text_input[frame_indicator == 1]
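        # Boolean-mask indexing flattens across the batch; assuming exactly one
        # frame token per sentence, this yields (batch_size, embedding_dim).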
        # get sizes
        batch_size, sequence_length, _ = embedded_text_input.size()
        # outputs
        logits = self.tag_projection_layer(embedded_text_input)
        frame_logits = self.frame_projection_layer(frame_embeddings)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])

        frame_probabilities = F.softmax(frame_logits, dim=-1)
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict = {
            "logits": logits,
            "frame_logits": frame_logits,
            "class_probabilities": class_probabilities,
            "frame_probabilities": frame_probabilities,
            "mask": mask,
        }
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        lemmas = [lemma for x in metadata for lemma in x["lemmas"]]
        output_dict["words"] = list(words)
        output_dict["lemma"] = lemmas
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # compute role loss
            role_loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            # compute frame loss
            frame_tags_filtered = frame_tags[frame_indicator == 1]
            frame_loss = self.frame_criterion(frame_logits,
                                              frame_tags_filtered)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            self.f1_frame_metric(frame_logits, frame_tags_filtered)
            output_dict["frame_loss"] = frame_loss
            output_dict["role_loss"] = role_loss
            output_dict["loss"] = (role_loss + frame_loss) / 2
        return output_dict
Example #5
    def test_srl_eval_correctly_scores_identical_tags(self):
        batch_verb_indices = [3, 8, 2, 0]
        batch_sentences = [
            [
                "Mali",
                "government",
                "officials",
                "say",
                "the",
                "woman",
                "'s",
                "confession",
                "was",
                "forced",
                ".",
            ],
            [
                "Mali",
                "government",
                "officials",
                "say",
                "the",
                "woman",
                "'s",
                "confession",
                "was",
                "forced",
                ".",
            ],
            [
                "The",
                "prosecution",
                "rested",
                "its",
                "case",
                "last",
                "month",
                "after",
                "four",
                "months",
                "of",
                "hearings",
                ".",
            ],
            ["Come", "in", "and", "buy", "."],
        ]
        batch_bio_predicted_tags = [
            [
                "B-ARG0",
                "I-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "O",
            ],
            [
                "O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1",
                "B-V", "B-ARG2", "O"
            ],
            [
                "B-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "O",
            ],
            ["B-V", "B-AM-DIR", "O", "O", "O"],
        ]
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format, )

        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [
            [
                "B-ARG0",
                "I-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "O",
            ],
            [
                "O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1",
                "B-V", "B-ARG2", "O"
            ],
            [
                "B-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "O",
            ],
            ["B-V", "B-AM-DIR", "O", "O", "O"],
        ]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer(ignore_classes=["V"])
        srl_scorer(batch_verb_indices, batch_sentences,
                   batch_conll_predicted_tags, batch_conll_gold_tags)
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 18
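        # 18 metrics = 3 (precision/recall/F1) x 6 groups: ARG0, ARG1, ARG2,
        # ARGM-TMP, AM-DIR, and overall ("V" is excluded via ignore_classes).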
        assert_allclose(metrics["precision-ARG0"], 1.0)
        assert_allclose(metrics["recall-ARG0"], 1.0)
        assert_allclose(metrics["f1-measure-ARG0"], 1.0)
        assert_allclose(metrics["precision-ARG1"], 1.0)
        assert_allclose(metrics["recall-ARG1"], 1.0)
        assert_allclose(metrics["f1-measure-ARG1"], 1.0)
        assert_allclose(metrics["precision-ARG2"], 1.0)
        assert_allclose(metrics["recall-ARG2"], 1.0)
        assert_allclose(metrics["f1-measure-ARG2"], 1.0)
        assert_allclose(metrics["precision-ARGM-TMP"], 1.0)
        assert_allclose(metrics["recall-ARGM-TMP"], 1.0)
        assert_allclose(metrics["f1-measure-ARGM-TMP"], 1.0)
        assert_allclose(metrics["precision-AM-DIR"], 1.0)
        assert_allclose(metrics["recall-AM-DIR"], 1.0)
        assert_allclose(metrics["f1-measure-AM-DIR"], 1.0)
        assert_allclose(metrics["precision-overall"], 1.0)
        assert_allclose(metrics["recall-overall"], 1.0)
        assert_allclose(metrics["f1-measure-overall"], 1.0)
Example #6
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            verb_indicator: torch.Tensor,
            sentence_end: torch.LongTensor,
            metadata: List[Any],
            tags: torch.LongTensor = None,
            offsets: torch.LongTensor = None):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """

        if isinstance(self.bert_model,
                      PretrainedTransformerMismatchedEmbedder):
            encoder_inputs = tokens["tokens"]
            if self.bert_config.type_vocab_size > 1:
                encoder_inputs["type_ids"] = verb_indicator
            encoded_text = self.bert_model(**encoder_inputs)
            batch_size = encoded_text.shape[0]
            if self.bert_config.type_vocab_size == 1:
                verb_embeddings = encoded_text[
                    torch.arange(batch_size).to(encoded_text.device),
                    verb_indicator.argmax(1), :]
                verb_embeddings = torch.where(
                    (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                        1, verb_embeddings.shape[-1]), verb_embeddings,
                    torch.zeros_like(verb_embeddings))
                encoded_text = torch.cat(
                    (encoded_text, verb_embeddings.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=2)
            mask = tokens["tokens"]["mask"]
        else:
            mask = get_text_field_mask(tokens)
            bert_embeddings, _ = self.bert_model(
                input_ids=util.get_token_ids_from_text_field_tensors(tokens),
                # token_type_ids=verb_indicator,
                attention_mask=mask,
            )

            batch_size, _ = mask.size()
            embedded_text_input = self.embedding_dropout(bert_embeddings)
            # Restrict to sentence part
            sentence_mask = (torch.arange(mask.shape[1]).unsqueeze(0).repeat(
                batch_size, 1).to(mask.device) <
                             sentence_end.unsqueeze(1).repeat(
                                 1, mask.shape[1])).long()
            cutoff = sentence_end.max().item()
            if self._encoder is None:
                encoded_text = embedded_text_input
                mask = sentence_mask[:, :cutoff].contiguous()
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
            else:
                predicate_embeddings = self.predicate_embedding(verb_indicator)
                encoder_inputs = torch.cat(
                    (embedded_text_input, predicate_embeddings), dim=-1)
                encoded_text = self._encoder(encoder_inputs,
                                             mask=sentence_mask.bool())
                predicate_index = (verb_indicator * torch.arange(
                    start=verb_indicator.shape[-1] - 1, end=-1,
                    step=-1).to(mask.device).unsqueeze(0).repeat(
                        batch_size, 1)).argmax(1)
                predicate_hidden = encoded_text[
                    torch.arange(batch_size).to(mask.device), predicate_index]
                predicate_exists, _ = verb_indicator.max(1)
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
                mask = sentence_mask[:, :cutoff].contiguous()
                predicate_exists = predicate_exists.unsqueeze(1).repeat(
                    1, encoded_text.shape[-1])
                predicate_hidden = torch.where(
                    predicate_exists > 0, predicate_hidden,
                    torch.zeros_like(predicate_hidden))
                encoded_text = torch.cat(
                    (encoded_text, predicate_hidden.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=-1)

        sequence_length = encoded_text.shape[1]
        logits = self.tag_projection_layer(encoded_text)
        if self._lp and sequence_length <= 100:
            eps = 1e-4
            Q = eps * torch.eye(
                sequence_length * self.num_classes,
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            p = logits.view(batch_size, -1)
            G = -1 * torch.eye(
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            h = torch.zeros_like(p)
            A = torch.arange(sequence_length *
                             self.num_classes).unsqueeze(0).repeat(
                                 sequence_length, 1)
            A2 = torch.arange(sequence_length).unsqueeze(1).repeat(
                1, sequence_length * self.num_classes) * self.num_classes
            A = torch.where((A >= A2) & (A < A2 + self.num_classes),
                            torch.ones_like(A), torch.zeros_like(A))
            A = A.unsqueeze(0).repeat(batch_size, 1,
                                      1).to(logits.device).float()
            b = torch.ones_like(A[:, :, 0])
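            # QPFunction solves min_z 0.5 * z^T Q z + p^T z subject to the
            # supplied constraints: Q = eps*I keeps the objective strongly
            # convex, and each row of A constrains one token's class block to
            # sum to 1. The G/h inequalities are passed as empty tensors, so
            # the relu below clips any negative coordinates instead.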
            probs = QPFunction()(Q, p, torch.autograd.Variable(torch.Tensor()),
                                 torch.autograd.Variable(torch.Tensor()), A, b)
            probs = probs.view(batch_size, sequence_length, self.num_classes)
            """logits_shape = logits.shape
            logits = torch.where(mask.bool().unsqueeze(-1).repeat(1, 1, logits.shape[-1]), logits, logits-10000)
            max_sequence_length = min([l for l in self.lengths if l >= sequence_length])
            if max_sequence_length > logits_shape[1]:
                logits = torch.cat((logits, torch.zeros((batch_size, max_sequence_length-logits_shape[1], logits_shape[2])).to(logits.device)), dim=1)
            lp_layer = self._layer_list[self.length_map[max_sequence_length]]
            probs, = lp_layer(logits)
            print(torch.isnan(probs).any())
            if max_sequence_length > logits_shape[1]:
                probs = probs[:,:logits_shape[1],:]"""
            logits = (torch.nn.functional.relu(probs) + 1e-4).log()
        if self._lpsmap:
            if self._lpsmap_core_only:
                all_logits = logits
            else:
                all_logits = torch.cat((logits, 0.5 * torch.ones(
                    (batch_size, 1, logits.shape[-1])).to(logits.device)),
                                       dim=1)
            probs = []
            for i in range(batch_size):
                if self.constrain_crf_decoding:
                    unaries = logits[i, :, :].view(-1).cpu()
                    additionals = self.crf.transitions.view(-1).repeat(
                        sequence_length) + 10000 * (
                            self.crf._constraint_mask[:-2, :-2] -
                            1).view(-1).repeat(sequence_length)
                    start_transitions = self.crf.start_transitions + 10000 * (
                        self.crf._constraint_mask[-2, :-2] - 1)
                    end_transitions = self.crf.end_transitions + 10000 * (
                        self.crf._constraint_mask[-1, :-2] - 1)
                    additionals = torch.cat(
                        (additionals, start_transitions, end_transitions),
                        dim=0).cpu()
                    fg = TorchFactorGraph()
                    x = fg.variable_from(unaries)
                    f = PFactorSequence()

                    f.initialize(
                        [self.num_classes for _ in range(sequence_length)])
                    factor = TorchOtherFactor(f, x, additionals)
                    fg.add(factor)
                    # add budget constraint for each state
                    for state in self._core_roles:
                        vars_state = x[state::self.num_classes]
                        fg.add(AtMostOne(vars_state))
                    # solve SparseMAP
                    fg.solve(max_iter=200)
                    # Read the solved SparseMAP posteriors back from the
                    # factor graph (appending the raw unaries would ignore
                    # the solve).
                    probs.append(
                        x.value.view(sequence_length,
                                     self.num_classes).to(logits.device))
                else:
                    fg = TorchFactorGraph()
                    x = fg.variable_from(all_logits[i, :, :].cpu())
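                    # Xor adds an exactly-one constraint over each token's tag
                    # block; AtMostOne budgets each core role (e.g. ARG0) to
                    # appear at most once per predicate.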
                    for j in range(sequence_length):
                        fg.add(Xor(x[j, :]))
                    for j in self._core_roles:
                        fg.add(AtMostOne(x[:sequence_length, j]))
                    if not self._lpsmap_core_only:
                        full_sequence = list(range(sequence_length))
                        base_roles = {
                            second
                            for (_, second) in self._r_roles + self._c_roles
                        }
                        """for (r_role, base_role) in self._r_roles+self._c_roles:
                            for j in range(sequence_length):
                                fg.add(Imply(x[full_sequence+[j],[base_role]*sequence_length+[r_role]], negated=[True]*(sequence_length+1)))"""
                        for base_role in base_roles:
                            fg.add(OrOut(x[:, base_role]))
                        for (r_role,
                             base_role) in self._r_roles + self._c_roles:
                            fg.add(OrOut(x[:, r_role]))
                            fg.add(
                                Or(x[[sequence_length, sequence_length],
                                     [r_role, base_role]],
                                   negated=[True, False]))
                    max_iter = 100
                    if not self._lpsmap_core_only:
                        max_iter = min(max_iter, 400)
                    elif (not self.training) and not self._val_inference:
                        max_iter = min(max_iter, 200)
                    fg.solve(max_iter=max_iter)
                    probs.append(x.value[:sequence_length, :].contiguous().to(
                        logits.device))
            class_probabilities = torch.stack(probs)
            # class_probabilities = self.lpsmap(logits)
            max_seq_length = 200
            # if self.lpsmap is None:
            """with torch.no_grad():
                # self.lpsmap = LpSparseMap(num_rows=sequence_length, num_cols=self.num_classes, batch_size=batch_size, device=logits.device, constraints=[('xor', ('row', list(range(sequence_length)))), ('budget', ('col', self._core_roles))])
                max_iter = 1000
                constraint_types = ["xor", "budget"]
                constraint_dims = ["row", "col"]
                constraint_sets = [list(range(sequence_length)), self._core_roles]
                class_probabilities = lpsmap(logits, constraint_types, constraint_dims, constraint_sets, max_iter)
                # if max_seq_length > sequence_length:
                #     logits = torch.cat((logits, -9999.*torch.ones((batch_size, max_seq_length-sequence_length, self.num_classes)).to(logits.device)), dim=1)
                # class_probabilities = self.lpsmap.solve(logits, max_iter=max_iter)"""
            # logits = (class_probabilities+1e-4).log()
        else:
            reshaped_log_probs = logits.view(-1, self.num_classes)
            class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
                [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            if self._lpsmap:
                loss = LpsmapLoss.apply(logits, class_probabilities, tags,
                                        mask)
                # tags_1hot = torch.zeros_like(class_probabilities).scatter_(2, tags.unsqueeze(-1), torch.ones_like(class_probabilities))
                # loss = -(tags_1hot*class_probabilities*mask.unsqueeze(-1).repeat(1, 1, class_probabilities.shape[-1])).sum()
            elif self.constrain_crf_decoding:
                loss = -self.crf(logits, tags, mask)
            else:
                loss = sequence_cross_entropy_with_logits(
                    logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                if self.constrain_crf_decoding and not self._lpsmap:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format([
                            self.vocab.get_token_from_index(
                                tag, namespace=self._label_namespace)
                            for tag in seq
                        ]) for (seq, _) in self.crf.viterbi_tags(logits, mask)
                    ]
                else:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format(tags)
                        for tags in batch_bio_predicted_tags
                    ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
            output_dict["gold_tags"] = [x["gold_tags"] for x in metadata]
        return output_dict
Example #7
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        spans : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)` of candidate span
            `(start, end)` indices produced by `SpanField`s; padded spans are `(-1, -1)`.
        span_labels : `torch.LongTensor`, required.
            Integer gold labels for the candidate spans, of shape
            `(batch_size, num_spans)`.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            Metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under the 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            bert_embeddings, spans)

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(
                embedded_text_input, mask)
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            endpoint_span_embeddings = self._endpoint_span_extractor(
                contextualized_embeddings, spans)

            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            # span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)
            span_embeddings = endpoint_span_embeddings
        else:
            span_embeddings = attended_span_embeddings

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * sequence_length))
        num_spans = spans.shape[1]
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape: (batch_size, num_spans) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)
        verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, embedded_text_input.shape[-1])
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(
            verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        verb_embeddings = verb_embeddings.squeeze(1)
        verb_embeddings = torch.where(
            (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                1, verb_embeddings.shape[-1]), verb_embeddings,
            torch.zeros_like(verb_embeddings))
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, spans.shape[1])
        span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        concatenated_span_embeddings = torch.cat(
            (span_embeddings, verb_embeddings.unsqueeze(1).repeat(
                1, span_embeddings.shape[1], 1)),
            dim=2)
        hidden = self.hidden_layer(concatenated_span_embeddings)
        predictions = self.output_layer(hidden)
        # predictions += top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1)
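        # Prepend a constant zero score for label index 0 (presumably the null
        # label), so the remaining labels are scored relative to it.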
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1)

        output_dict = {}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = (self._ce_loss(predictions.view(-1, predictions.shape[-1]),
                                  top_span_labels.view(-1)) *
                    top_span_mask.float().view(-1)
                    ).sum() / top_span_mask.float().sum()
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask,
                    output_dict)
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict