Example #1
    def test_span_metrics_are_computed_correctly(self):
        batch_verb_indices = [2]
        batch_sentences = [["The", "cat", "loves", "hats", "."]]
        batch_bio_predicted_tags = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags) for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags) for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer(ignore_classes=["V"])
        srl_scorer(
            batch_verb_indices, batch_sentences, batch_conll_predicted_tags, batch_conll_gold_tags
        )
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 9
        assert_allclose(metrics["precision-ARG0"], 0.0)
        assert_allclose(metrics["recall-ARG0"], 0.0)
        assert_allclose(metrics["f1-measure-ARG0"], 0.0)
        assert_allclose(metrics["precision-ARG1"], 0.5)
        assert_allclose(metrics["recall-ARG1"], 1.0)
        assert_allclose(metrics["f1-measure-ARG1"], 2 / 3)
        assert_allclose(metrics["precision-overall"], 1 / 3)
        assert_allclose(metrics["recall-overall"], 1 / 2)
        assert_allclose(
            metrics["f1-measure-overall"], (2 * (1 / 3) * (1 / 2)) / ((1 / 3) + (1 / 2))
        )
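
The asserted values can be reproduced by hand. Ignoring the "V" span, the prediction contains three labeled spans and the gold annotation two, and only the ARG1 span over "hats" matches exactly; the nine reported metrics are precision/recall/F1 for each scored class (ARG0, ARG1) plus the overall triple. A minimal sketch of the arithmetic (an illustration only; SrlEvalScorer itself wraps the official srl-eval.pl script):

# Spans as (label, (start, end)) tuples, with the "V" span excluded.
predicted_spans = {("ARG0", (0, 0)), ("ARG1", (1, 1)), ("ARG1", (3, 3))}
gold_spans = {("ARG0", (0, 1)), ("ARG1", (3, 3))}

true_positives = predicted_spans & gold_spans           # {("ARG1", (3, 3))}
precision = len(true_positives) / len(predicted_spans)  # 1/3 -> precision-overall
recall = len(true_positives) / len(gold_spans)          # 1/2 -> recall-overall
f1 = 2 * precision * recall / (precision + recall)      # -> f1-measure-overall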
Example #2
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.Tensor],
            verb_indicator: torch.Tensor,
            metadata: List[Dict[str, Any]],
            tags: torch.LongTensor = None):
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. For this model, the tokens must be indexed by a
            ``SingleIdTokenIndexer`` that indexes wordpieces from the BERT vocabulary.
        verb_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``
        metadata : ``List[Dict[str, Any]]``, required
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        bert_embeddings, _ = self.bert_model(input_ids=tokens["tokens"],
                                             token_type_ids=verb_indicator,
                                             attention_mask=mask,
                                             output_all_encoded_layers=False)

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        logits = self.tag_projection_layer(embedded_text_input)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from decode()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of decode() to a separate function.
                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(batch_verb_indices, batch_sentences,
                                 batch_conll_predicted_tags,
                                 batch_conll_gold_tags)
            output_dict["loss"] = loss
        return output_dict
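
The "wordpiece_offsets" recorded in the output dictionary are what decode() uses to map wordpiece-level predictions back to word-level tags. A minimal sketch of that recovery step, assuming each offset is the index of the first wordpiece of a word (a hypothetical helper, not part of the model):

def wordpiece_tags_to_word_tags(wordpiece_tags, offsets):
    # Keep the tag predicted for the first wordpiece of each word.
    return [wordpiece_tags[i] for i in offsets]

word_tags = wordpiece_tags_to_word_tags(
    ["B-ARG0", "I-ARG0", "I-ARG0", "B-V", "O"], offsets=[0, 3, 4]
)  # -> ["B-ARG0", "B-V", "O"]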
Example #3
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.LongTensor,
        tags: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters

        tokens : TextFieldTensors, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
            tensors.  At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
            Tensor(batch_size, num_tokens)}`. This dictionary will have the same keys as were used
            for the `TokenIndexers` when you created the `TextField` representing your
            sequence.  The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = None)
            metadata containing the original words in the sentence and the verb to compute the
            frame for, under 'words' and 'verb' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, embedded_verb_indicator], -1
        )
        batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

        encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes]
        )
        output_dict = {"logits": logits, "class_probabilities": class_probabilities}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.
        output_dict["mask"] = mask

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing
            )
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"] for example_metadata in metadata
                ]
                batch_sentences = [example_metadata["words"] for example_metadata in metadata]
                # Get the BIO tags from decode()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of decode() to a separate function.
                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags) for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"] for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags) for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss

        if metadata is not None:
            words, verbs = zip(*[(x["words"], x["verb"]) for x in metadata])
            output_dict["words"] = list(words)
            output_dict["verb"] = list(verbs)
        return output_dict
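
The verb_indicator consumed above is a 0/1 vector over tokens marking the predicate position, which binary_feature_embedding turns into a dense feature concatenated onto each token embedding. A quick sketch of constructing one (illustrative values):

import torch

tokens = ["The", "cat", "loves", "hats", "."]
verb_index = 2  # "loves"
verb_indicator = torch.tensor(
    [[1 if i == verb_index else 0 for i in range(len(tokens))]]
)  # shape (batch_size=1, num_tokens=5); all zeros when there is no predicate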
Example #4
 def test_bio_tags_correctly_convert_to_conll_format(self):
     bio_tags = ["B-ARG-1", "I-ARG-1", "O", "B-V", "B-ARGM-ADJ", "O"]
     conll_tags = convert_bio_tags_to_conll_format(bio_tags)
     assert conll_tags == ["(ARG-1*", "*)", "*", "(V*)", "(ARGM-ADJ*)", "*"]
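
The conversion under test opens a span with "(LABEL" at the first tag of each span, closes it with ")" at the last, and renders tokens outside any span as "*". One possible implementation that satisfies this test (a sketch; the library's own function may differ in detail):

from typing import List

def convert_bio_tags_to_conll_format(labels: List[str]) -> List[str]:
    sentence_length = len(labels)
    conll_labels = []
    for i, label in enumerate(labels):
        if label == "O":
            conll_labels.append("*")
            continue
        new_label = "*"
        # Open a span at a B- tag, or at an I- tag whose label differs
        # from the previous tag's label.
        if label[0] == "B" or i == 0 or label[1:] != labels[i - 1][1:]:
            new_label = "(" + label[2:] + new_label
        # Close the span if this is the last token, the next tag opens a
        # new span, or the next tag carries a different label.
        if i == sentence_length - 1 or labels[i + 1][0] == "B" or label[1:] != labels[i + 1][1:]:
            new_label = new_label + ")"
        conll_labels.append(new_label)
    return conll_labels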
Example #5
    def test_srl_eval_correctly_scores_identical_tags(self):
        batch_verb_indices = [3, 8, 2, 0]
        batch_sentences = [
            [
                "Mali",
                "government",
                "officials",
                "say",
                "the",
                "woman",
                "'s",
                "confession",
                "was",
                "forced",
                ".",
            ],
            [
                "Mali",
                "government",
                "officials",
                "say",
                "the",
                "woman",
                "'s",
                "confession",
                "was",
                "forced",
                ".",
            ],
            [
                "The",
                "prosecution",
                "rested",
                "its",
                "case",
                "last",
                "month",
                "after",
                "four",
                "months",
                "of",
                "hearings",
                ".",
            ],
            ["Come", "in", "and", "buy", "."],
        ]
        batch_bio_predicted_tags = [
            [
                "B-ARG0",
                "I-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "O",
            ],
            ["O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1", "B-V", "B-ARG2", "O"],
            [
                "B-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "O",
            ],
            ["B-V", "B-AM-DIR", "O", "O", "O"],
        ]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags) for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [
            [
                "B-ARG0",
                "I-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "O",
            ],
            ["O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1", "B-V", "B-ARG2", "O"],
            [
                "B-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "O",
            ],
            ["B-V", "B-AM-DIR", "O", "O", "O"],
        ]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags) for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer(ignore_classes=["V"])
        srl_scorer(
            batch_verb_indices, batch_sentences, batch_conll_predicted_tags, batch_conll_gold_tags
        )
        metrics = srl_scorer.get_metric()
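        # 18 metrics: precision/recall/F1 for each of the five scored tag
        # classes (ARG0, ARG1, ARG2, ARGM-TMP, AM-DIR; "V" is ignored) plus
        # the three overall scores.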
        assert len(metrics) == 18
        assert_allclose(metrics["precision-ARG0"], 1.0)
        assert_allclose(metrics["recall-ARG0"], 1.0)
        assert_allclose(metrics["f1-measure-ARG0"], 1.0)
        assert_allclose(metrics["precision-ARG1"], 1.0)
        assert_allclose(metrics["recall-ARG1"], 1.0)
        assert_allclose(metrics["f1-measure-ARG1"], 1.0)
        assert_allclose(metrics["precision-ARG2"], 1.0)
        assert_allclose(metrics["recall-ARG2"], 1.0)
        assert_allclose(metrics["f1-measure-ARG2"], 1.0)
        assert_allclose(metrics["precision-ARGM-TMP"], 1.0)
        assert_allclose(metrics["recall-ARGM-TMP"], 1.0)
        assert_allclose(metrics["f1-measure-ARGM-TMP"], 1.0)
        assert_allclose(metrics["precision-AM-DIR"], 1.0)
        assert_allclose(metrics["recall-AM-DIR"], 1.0)
        assert_allclose(metrics["f1-measure-AM-DIR"], 1.0)
        assert_allclose(metrics["precision-overall"], 1.0)
        assert_allclose(metrics["recall-overall"], 1.0)
        assert_allclose(metrics["f1-measure-overall"], 1.0)
Example #6
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                verb_indicator: torch.LongTensor,
                adj: torch.LongTensor,
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        adj : torch.LongTensor, required
            A batch of adjacency matrices over the sentence tokens (one per example),
            each consumed by the GCN layer together with that example's encoded text.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence and the verb to compute the
            frame for, under 'words' and 'verb' keys, respectively.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat([embedded_text_input, embedded_verb_indicator], -1)
        batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

        encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)
        # Run the GCN over each example's tokens with its adjacency matrix,
        # then stack the per-example outputs back into a batch.
        gcn_outputs = []
        for i, single_adj in enumerate(adj):
            gcn_output = self.gcn_layer(encoded_text[i], single_adj)
            gcn_outputs.append(gcn_output.unsqueeze(0))
        logits = torch.cat(gcn_outputs, dim=0)
        logits = self.decoder(logits, mask)
        logits = self.tag_projection_layer(logits)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size,
                                                                          sequence_length,
                                                                          self.num_classes])
        output_dict = {"logits": logits, "class_probabilities": class_probabilities}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.
        output_dict["mask"] = mask

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(logits,
                                                      tags,
                                                      mask,
                                                      label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [example_metadata["verb_index"] for example_metadata in metadata]
                batch_sentences = [example_metadata["words"] for example_metadata in metadata]
                # Get the BIO tags from decode()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of decode() to a separate function.
                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
                batch_conll_predicted_tags = [convert_bio_tags_to_conll_format(tags) for
                                              tags in batch_bio_predicted_tags]
                batch_bio_gold_tags = [example_metadata["gold_tags"] for example_metadata in metadata]
                batch_conll_gold_tags = [convert_bio_tags_to_conll_format(tags) for
                                         tags in batch_bio_gold_tags]
                self.span_metric(batch_verb_indices,
                                 batch_sentences,
                                 batch_conll_predicted_tags,
                                 batch_conll_gold_tags)
            output_dict["loss"] = loss

        if metadata is not None:
            words, verbs = zip(*[(x["words"], x["verb"]) for x in metadata])
            output_dict["words"] = list(words)
            output_dict["verb"] = list(verbs)
        return output_dict
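
The gcn_layer applied per example above takes one sentence's encoded tokens as node features together with a token-by-token adjacency matrix. A minimal single-layer graph convolution of the kind this could be (hypothetical; the actual layer used by the snippet is not shown):

import torch
import torch.nn as nn

class SimpleGCNLayer(nn.Module):
    # One graph-convolution step: H' = relu(A @ H @ W).
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, node_features, adjacency):
        # node_features: (num_tokens, in_dim); adjacency: (num_tokens, num_tokens),
        # typically symmetrically normalized with self-loops added.
        return torch.relu(adjacency @ self.linear(node_features))

layer = SimpleGCNLayer(in_dim=300, out_dim=300)
out = layer(torch.randn(7, 300), torch.eye(7))  # -> (7, 300)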
Example #7
    def forward(  # type: ignore
        self,
        tokens: Dict[str, torch.LongTensor],
        verb_indicator: torch.LongTensor,
        img_emb,
        tags: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        img_emb : torch.Tensor, required
            An image embedding of shape ``(batch_size, image_embedding_size)``,
            passed through ``self.img_enc`` before the text-to-image attention.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence and the verb to compute the
            frame for, under 'words' and 'verb' keys, respectively.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.embedding_dropout(
            self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(
            verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, embedded_verb_indicator], -1)
        batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()
        encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)
        # final_states of shape (batch_size, embedding_size) could be recovered
        # here for an image-sentence alignment loss, but this path is disabled:
        # final_states = get_final_encoder_states(encoded_text, mask)
        # Moving the image encoder onto the GPU inside forward() is a workaround;
        # ideally this would be done once at construction time.
        if torch.cuda.is_available():
            self.img_enc.cuda()
        image_embedding_resized = self.img_enc(img_emb)
        # Attention of each token over the image objects:
        # atts has shape (batch_size, text_seq_length, image_obj_num).
        atts = self.attention(encoded_text, image_embedding_resized)
        atts = masked_softmax(atts, mask.unsqueeze(2), dim=2)
        # Use the normalised attention to build a per-token image context and
        # concatenate it onto the encoded text.
        contexts = torch.bmm(atts, image_embedding_resized)
        att_code = torch.cat([encoded_text, contexts], 2)

        logits = self.tag_projection_layer(att_code)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.
        output_dict["mask"] = mask

        if tags is not None:
            seq2seq_loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            # An integrated loss combining the tagging loss with an image-sentence
            # alignment loss is left disabled:
            # im_sent_loss = self.vse_loss(final_states, image_embedding_resized)
            # loss = seq2seq_loss * (1 - self.lamb) + im_sent_loss.sum() * self.lamb
            loss = seq2seq_loss
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from decode()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of decode() to a separate function.
                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss

        if metadata is not None:
            words, verbs = zip(*[(x["words"], x["verb"]) for x in metadata])
            output_dict["words"] = list(words)
            output_dict["verb"] = list(verbs)
        return output_dict
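
A quick shape check of the attention-context computation in the middle of this forward pass (standalone, with made-up sizes):

import torch

batch_size, text_len, num_objects, dim = 2, 7, 5, 300
atts = torch.softmax(torch.randn(batch_size, text_len, num_objects), dim=2)
image_objects = torch.randn(batch_size, num_objects, dim)
contexts = torch.bmm(atts, image_objects)  # (batch_size, text_len, dim)
att_code = torch.cat([torch.randn(batch_size, text_len, dim), contexts], 2)
# att_code: (batch_size, text_len, 2 * dim), fed to the tag projection layer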
Example #8
    def test_srl_eval_correctly_scores_identical_tags(self):
        batch_verb_indices = [3, 8, 2]
        batch_sentences = [
            ["Mali", "government", "officials", "say", "the", "woman", "'s",
             "confession", "was", "forced", "."],
            ["Mali", "government", "officials", "say", "the", "woman", "'s",
             "confession", "was", "forced", "."],
            ["The", "prosecution", "rested", "its", "case", "last", "month",
             "after", "four", "months", "of", "hearings", "."],
        ]
        batch_bio_predicted_tags = [
            ["B-ARG0", "I-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1", "I-ARG1",
             "I-ARG1", "I-ARG1", "I-ARG1", "O"],
            ["O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1", "B-V",
             "B-ARG2", "O"],
            ["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1", "B-ARGM-TMP",
             "I-ARGM-TMP", "B-ARGM-TMP", "I-ARGM-TMP", "I-ARGM-TMP",
             "I-ARGM-TMP", "I-ARGM-TMP", "O"],
        ]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [
            ["B-ARG0", "I-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1", "I-ARG1",
             "I-ARG1", "I-ARG1", "I-ARG1", "O"],
            ["O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1", "B-V",
             "B-ARG2", "O"],
            ["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1", "B-ARGM-TMP",
             "I-ARGM-TMP", "B-ARGM-TMP", "I-ARGM-TMP", "I-ARGM-TMP",
             "I-ARGM-TMP", "I-ARGM-TMP", "O"],
        ]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer(ignore_classes=["V"])
        srl_scorer(batch_verb_indices, batch_sentences,
                   batch_conll_predicted_tags, batch_conll_gold_tags)
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 15
        assert_allclose(metrics["precision-ARG0"], 1.0)
        assert_allclose(metrics["recall-ARG0"], 1.0)
        assert_allclose(metrics["f1-measure-ARG0"], 1.0)
        assert_allclose(metrics["precision-ARG1"], 1.0)
        assert_allclose(metrics["recall-ARG1"], 1.0)
        assert_allclose(metrics["f1-measure-ARG1"], 1.0)
        assert_allclose(metrics["precision-ARG2"], 1.0)
        assert_allclose(metrics["recall-ARG2"], 1.0)
        assert_allclose(metrics["f1-measure-ARG2"], 1.0)
        assert_allclose(metrics["precision-ARGM-TMP"], 1.0)
        assert_allclose(metrics["recall-ARGM-TMP"], 1.0)
        assert_allclose(metrics["f1-measure-ARGM-TMP"], 1.0)
        assert_allclose(metrics["precision-overall"], 1.0)
        assert_allclose(metrics["recall-overall"], 1.0)
        assert_allclose(metrics["f1-measure-overall"], 1.0)