Example #1
    def forward(self,  # type: ignore
                tokens: TextFieldTensors,
                spans: torch.IntTensor,
                ner_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        mask = util.get_text_field_mask(tokens)
        span_mask = (spans[:, :, 0] >= 0)
        sentence_lengths = mask.sum(dim=1).long()

        embedded_text_input = self._text_field_embedder(tokens)
        span_embeddings = self._span_extractor(embedded_text_input, spans, span_mask)

        # cls_h = embedded_text_input[:, 0, :].unsqueeze(1).repeat(1, span_embeddings.shape[1], 1)

        # span_vectors = torch.cat((span_embeddings, cls_h), dim=2)

        span_vectors = span_embeddings

        # 32, 90, 1000
        ner_scores = self._ner_scorer(span_vectors)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = util.replace_masked_values(ner_scores, mask.bool(), -1e20)
        # The dummy_scores are the score for the null label.
        # dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
        # dummy_scores = ner_scores.new_zeros(*dummy_dims)
        # ner_scores_logits = torch.cat((dummy_scores, ner_scores), -1)

        ner_scores_probs = torch.sigmoid(ner_scores)

        # predictions = self.predict(ner_scores.detach().cpu(),
        #                            spans.detach().cpu(),
        #                            span_mask.detach().cpu(),
        #                            metadata)
        #
        # output_dict = {"predictions": predictions}

        relations = self.extract_relations(spans, ner_scores_probs)

        output_dict = {"relations": relations}

        if ner_labels is not None:
            self._ner_metric(ner_scores_probs, ner_labels, span_mask)

            self._relation_metric(relations, [m["relations"] for m in metadata])

            loss = self._loss(ner_scores, ner_labels.float())

            output_dict["loss"] = loss

        return output_dict
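
Example #1 masks out padded spans by overwriting their logits with a large negative value before the sigmoid, so padded positions end up with near-zero probability. Below is a minimal standalone sketch of that step, using torch.masked_fill in place of AllenNLP's util.replace_masked_values; the function name and shapes are illustrative, not taken from the example.

import torch

def mask_span_scores(ner_scores: torch.Tensor, span_mask: torch.Tensor) -> torch.Tensor:
    """Give padded spans a large negative score so sigmoid(score) is ~0.

    ner_scores: (batch_size, num_spans, num_labels)
    span_mask:  (batch_size, num_spans), True for real spans
    """
    mask = span_mask.unsqueeze(-1)                 # (batch_size, num_spans, 1)
    return ner_scores.masked_fill(~mask, -1e20)


# Toy usage:
scores = torch.randn(2, 3, 4)
mask = torch.tensor([[True, True, False], [True, False, False]])
probs = torch.sigmoid(mask_span_scores(scores, mask))   # padded spans get probability ~0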
Example #2
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None, **kwargs) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        # Shape: (batch_size, num_spans)
        span_mask = spans[:, :, 0] >= 0
        spans = F.relu(spans.float()).long()
        span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)
        # endpoint_span_embeddings = self._endpoint_span_extractor(text_embeddings, spans)
        # attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # span_embeddings = torch.cat([pooling_span_embeddings, endpoint_span_embeddings, attended_span_embeddings], -1)

        # Shape: (batch_size, num_spans, class_num)
        ne_scores = self._entity_scorer(span_embeddings)

        # Shape: (batch_size, num_spans)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "predicted_named_entities": predicted_named_entities}
        if labels is not None:
            ne_scores = ne_scores.reshape(-1, self.class_num)
            labels = labels.reshape(-1)
            span_mask = span_mask.reshape(-1)

            if self.sampling_rate > 0:
                neg_mask = labels == 0
                neg_sampling_mask = neg_mask & span_mask & (
                        torch.rand(labels.shape, device=labels.device) < self.sampling_rate)
                sampling_mask = neg_sampling_mask | ~neg_mask
                negative_log_likelihood = F.cross_entropy(ne_scores[sampling_mask], labels[sampling_mask])
            else:
                negative_log_likelihood = F.cross_entropy(ne_scores, labels)
            output_dict["loss"] = negative_log_likelihood

            self._metric_all(ne_scores, labels.reshape(-1), span_mask)
            self._metric_avg(ne_scores, labels.reshape(-1), span_mask)

        return output_dict
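
Example #2 fights the imbalance between null and labelled spans by downsampling negatives: every positive span contributes to the loss, while each negative span is kept only with probability sampling_rate. A minimal sketch of that sampling mask follows; the function name and the extra span_mask guard on positives are illustrative additions.

import torch
import torch.nn.functional as F

def sampled_cross_entropy(scores: torch.Tensor, labels: torch.Tensor,
                          span_mask: torch.Tensor, sampling_rate: float) -> torch.Tensor:
    """Cross-entropy over all positive spans plus a random subset of negative spans.

    scores:    (num_spans_total, num_classes) flattened logits
    labels:    (num_spans_total,) gold class ids, where 0 is the null/negative class
    span_mask: (num_spans_total,) bool, True for real (non-padded) spans
    """
    neg_mask = labels == 0
    keep_negatives = neg_mask & span_mask & (
        torch.rand(labels.shape, device=labels.device) < sampling_rate)
    keep = keep_negatives | (~neg_mask & span_mask)
    return F.cross_entropy(scores[keep], labels[keep])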
Example #3
    def forward(
        self,  # type: ignore
        text_embeddings: torch.FloatTensor,
        text_mask: torch.IntTensor,
        ner_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        # Shape: (Batch_size, Number of spans, H)
        span_feedforward = self._mention_feedforward(text_embeddings)
        ner_scores = self._ner_scorer(span_feedforward)
        predicted_ner = self._ner_crf.viterbi_tags(ner_scores, text_mask)

        predicted_ner = [x for x, y in predicted_ner]
        # Only compute gold tags when labels were actually provided.
        gold_ner = None
        if ner_labels is not None:
            gold_ner = [list(x[m.bool()].detach().cpu().numpy())
                        for x, m in zip(ner_labels, text_mask)]

        output = {"logits": ner_scores, "tags": predicted_ner, "gold_tags": gold_ner}

        if ner_labels is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self._ner_crf(ner_scores, ner_labels, text_mask)
            output["loss"] = -log_likelihood / text_embeddings.shape[0]

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = ner_scores * 0.0
            for i, instance_tags in enumerate(predicted_ner):
                for j, tag_id in enumerate(instance_tags):
                    if i >= ner_scores.shape[0] or j >= ner_scores.shape[1] or tag_id >= ner_scores.shape[2]:
                        breakpoint()
                    class_probabilities[i, j, tag_id] = 1

            self._ner_metrics(class_probabilities, ner_labels, text_mask.float())

        if metadata is not None:
            output["metadata"] = metadata

        return output
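
Example #3 feeds its metrics "class probabilities" built by one-hot encoding the Viterbi tags at the positions they were predicted for. A minimal standalone sketch of that conversion, with illustrative names:

import torch

def tags_to_one_hot(logits: torch.Tensor, tag_sequences) -> torch.Tensor:
    """Turn decoded tag sequences into one-hot tensors shaped like the logits.

    logits:        (batch_size, sequence_length, num_tags)
    tag_sequences: one list of predicted tag ids per instance (may be shorter
                   than sequence_length; trailing positions stay all-zero)
    """
    one_hot = torch.zeros_like(logits)
    for i, tags in enumerate(tag_sequences):
        for j, tag_id in enumerate(tags):
            one_hot[i, j, tag_id] = 1.0
    return one_hot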
Example #4
    def forward(
            self,  # type: ignore
            combined_source: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        """

        :param combined_source:
        :param label:
        :param metadata:
        :return:
        """
        embedded_source = self._text_field_embedder(
            combined_source)  # B * T * H
        source_mask = get_text_field_mask(combined_source)  # B * T
        embedded_source = self._variational_dropout(embedded_source)

        pooled = self._attn_pool(embedded_source, source_mask)  # B * H
        choice_score = self._output_ffl(pooled)  # B * 1

        output = torch.sigmoid(choice_score).squeeze(-1)  # B

        output_dict = {
            "label_logits": choice_score.squeeze(-1),
            "label_probs": output,
            "metadata": metadata
        }

        if label is not None:
            label = label.long().view(-1)
            loss = self._loss(output, label.float())
            self._auc(output, label)
            output_dict["loss"] = loss

        return output_dict
Example #5
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
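
Examples #5 and #7 rely on the SpanField padding convention: padded spans are (-1, -1), the span mask is derived from the sign of the start index, and F.relu clamps the padding to 0 so the spans can safely be used as indices. A minimal sketch of that convention:

import torch
import torch.nn.functional as F

# SpanField-style padding: padded spans are (-1, -1).
spans = torch.tensor([[[0, 2], [3, 5], [-1, -1]]])   # (batch_size=1, num_spans=3, 2)

span_mask = spans[:, :, 0] >= 0                      # tensor([[ True,  True, False]])
spans = F.relu(spans.float()).long()                 # (-1, -1) becomes (0, 0), safe to index with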
Example #6
    def forward(self,
                response: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                original_post: Optional[Dict[str, torch.LongTensor]] = None,
                weakpoints: Optional[torch.IntTensor] = None,
                fake_data: bool = False) -> Dict[str, torch.Tensor]:

        #print(original_post)
        #print(response)
        #print('label', label)
        '''
        print('LABEL', label[0])
        for key in original_post:
            print('ORIGINAL POST')            
            for i in range(original_post[key][0].size(0)):
                o = [self.vocab.get_token_from_index(int(index), key) for index in original_post[key][0][i] if int(index)]
                if len(o):
                    print(o)
            print('RESPONSE') 
            for i in range(response[key][0].size(0)):
                o = [self.vocab.get_token_from_index(int(index), key) for index in response[key][0][i] if int(index)]
                if len(o):
                    print(o)
        '''

        embedded_response = self._response_embedder(response,
                                                    num_wrapping_dims=1)

        #print(embedded_op.shape, embedded_response.shape)
        batch_size, max_response_sentences, max_response_words, response_dim = embedded_response.shape

        response_mask = get_text_field_mask(response,
                                            num_wrapping_dims=1).float()

        #get weighted average of words in sentence
        embedded_response = embedded_response.view(
            batch_size * max_response_sentences, max_response_words, -1)
        response_mask = response_mask.view(batch_size * max_response_sentences,
                                           max_response_words)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_response = self.rnn_input_dropout(embedded_response)

        #print(embedded_op.shape, op_mask.shape, embedded_response.shape, response_mask.shape)

        response_attention = self._response_word_attention(
            embedded_response, response_mask)
        embedded_response = weighted_sum(embedded_response,
                                         response_attention).view(
                                             batch_size,
                                             max_response_sentences, -1)

        response_mask = response_mask.view(batch_size, max_response_sentences,
                                           -1).sum(dim=-1) > 0

        #print(embedded_op.shape, op_mask.shape, embedded_response.shape, response_mask.shape)
        # encode OP and response at sentence level
        encoded_response = self._response_encoder(embedded_response,
                                                  response_mask)

        if original_post is not None:
            embedded_op = self._op_embedder(original_post, num_wrapping_dims=1)
            _, max_op_sentences, max_op_words, op_dim = embedded_op.shape
            op_mask = get_text_field_mask(original_post,
                                          num_wrapping_dims=1).float()

            embedded_op = embedded_op.view(batch_size * max_op_sentences,
                                           max_op_words, -1)
            op_mask = op_mask.view(batch_size * max_op_sentences, max_op_words)

            # apply dropout for LSTM
            if self.rnn_input_dropout:
                embedded_op = self.rnn_input_dropout(embedded_op)

            op_attention = self._op_word_attention(embedded_op, op_mask)
            embedded_op = weighted_sum(embedded_op, op_attention).view(
                batch_size, max_op_sentences, -1)

            op_mask = op_mask.view(batch_size, max_op_sentences,
                                   -1).sum(dim=-1) > 0

            encoded_op = self._op_encoder(embedded_op, op_mask)

            combined_input = self._response_sentence_attention(
                encoded_op, encoded_response, op_mask, response_mask,
                self._op_sentence_attention)

        else:
            attn = self._op_sentence_attention(encoded_response, response_mask)
            combined_input = weighted_sum(encoded_response, attn)

        #now batch_size x n_dim
        #encoded_op = self._op_sentence_attention(encoded_op, op_mask)

        if self.dropout:
            combined_input = self.dropout(combined_input)

        label_logits = self._output_feedforward(combined_input).squeeze(-1)
        label_probs = torch.sigmoid(label_logits)
        #print(label_probs)
        predictions = label_probs > 0.5
        #print('predictions', predictions)
        #print('1-predictions', 1-predictions)
        #print(label)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "representation": combined_input
        }

        if label is not None:
            # Weight positive examples by the batch's negative/positive ratio to
            # counteract class imbalance.
            true_weight = (label == 0).sum().float() / (label == 1).sum().float()

            weight = label.eq(0).float() + label.eq(1).float() * true_weight
            loss = self._loss(label_logits, label.float(), weight=weight)

            if fake_data:
                self._fake_accuracy(predictions, label.byte())
                self._fake_fscore(
                    torch.stack([1 - predictions, predictions], dim=1), label)
            else:
                self._accuracy(predictions, label.byte())
                #self._cat_accuracy(torch.stack([1-predictions, predictions], dim=1), label.byte())
                self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                             label)

            output_dict["loss"] = loss

        return output_dict
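
Example #6 counteracts class imbalance by weighting each positive instance by the batch's negative-to-positive ratio inside the binary cross-entropy. A minimal sketch follows, assuming a BCE-with-logits loss (the example's self._loss is not shown) and guarding the illustrative version against a batch with no positives via clamp.

import torch
import torch.nn.functional as F

def imbalance_weighted_bce(logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy where positives are weighted by (#negatives / #positives)."""
    pos_weight = (label == 0).sum().float() / (label == 1).sum().float().clamp(min=1)
    weight = label.eq(0).float() + label.eq(1).float() * pos_weight
    return F.binary_cross_entropy_with_logits(logits, label.float(), weight=weight)


logits = torch.tensor([2.0, -1.0, -0.5, -3.0])
label = torch.tensor([1, 0, 0, 0])
loss = imbalance_weighted_bce(logits, label)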
Example #7
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        # span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

        # Shape:    (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, class_num + 1)
        ne_scores = self._compute_named_entity_scores(top_span_embeddings)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "top_spans": top_spans,
            "predicted_named_entities": predicted_named_entities
        }
        if labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices).squeeze(-1)
            negative_log_likelihood = F.cross_entropy(
                ne_scores.reshape(-1, self.class_num),
                pruned_gold_labels.reshape(-1))

            pruner_loss = F.binary_cross_entropy_with_logits(
                top_span_mention_scores.reshape(-1),
                (pruned_gold_labels.reshape(-1) != 0).float())
            loss = negative_log_likelihood + pruner_loss
            output_dict["loss"] = loss
            output_dict["pruner_loss"] = pruner_loss
            batch_size, _ = labels.shape
            all_scores = ne_scores.new_zeros(
                [batch_size * num_spans, self.class_num])
            all_scores[:, 0] = 1
            all_scores[flat_top_span_indices] = ne_scores.reshape(
                -1, self.class_num)
            all_scores = all_scores.reshape(
                [batch_size, num_spans, self.class_num])
            self._metric_all(all_scores, labels)
            self._metric_avg(all_scores, labels)
        return output_dict
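
Example #7 adds an auxiliary pruner loss next to the entity classification loss: a binary cross-entropy that pushes the mention scorer to give high scores to kept spans whose gold label is non-null. A minimal sketch of that term, with illustrative names:

import torch
import torch.nn.functional as F

def pruner_loss(mention_scores: torch.Tensor, pruned_gold_labels: torch.Tensor) -> torch.Tensor:
    """Auxiliary loss for the span pruner.

    mention_scores:     (batch_size, num_spans_to_keep, 1) raw pruner scores
    pruned_gold_labels: (batch_size, num_spans_to_keep) gold class ids of the kept spans
    """
    targets = (pruned_gold_labels.reshape(-1) != 0).float()
    return F.binary_cross_entropy_with_logits(mention_scores.reshape(-1), targets)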
Example #8
    def forward(
            self,  # type: ignore
            passage: Dict[str, torch.LongTensor],
            all_qa: Dict[str, torch.LongTensor],
            candidate: Dict[str, torch.LongTensor],
            combined_source: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """"""
        if self._with_knowledge:
            embedded_passage = self._text_field_embedder(passage)  # B * T * d
            passage_len = embedded_passage.size(1)
        embedded_all_qa = self._text_field_embedder(all_qa)  # B * U * d
        embedded_choice = self._text_field_embedder(candidate)  # B * V * d

        if self._with_knowledge:
            embedded_passage = self._variational_dropout(
                embedded_passage)  # B * T * d
        embedded_all_qa = self._variational_dropout(embedded_all_qa)
        embedded_choice = self._variational_dropout(
            embedded_choice)  # B * V * d

        all_qa_mask = util.get_text_field_mask(all_qa)  # B * U
        choice_mask = util.get_text_field_mask(candidate)  # B * V

        # Encoding
        if self._with_knowledge:
            # B * T * H
            passage_mask = util.get_text_field_mask(passage)  # B * T
            encoded_passage = self._variational_dropout(
                self._pseqlevel_enc(embedded_passage, passage_mask))
        # B * U * H
        if self._shared_rnn:
            encoded_allqa = self._variational_dropout(
                self._pseqlevel_enc(embedded_all_qa, all_qa_mask))
        else:
            encoded_allqa = self._variational_dropout(
                self._qaseqlevel_enc(embedded_all_qa, all_qa_mask))

        if self._with_knowledge and self._is_qdep_penc:
            # similarity matrix
            _, normalized_attn_mat = self._cart_attn(encoded_passage,
                                                     encoded_allqa,
                                                     all_qa_mask)  # B * T * U
            # question dependent passage encoding
            q_aware_passage_rep = sequential_weighted_avg(
                encoded_allqa, normalized_attn_mat)  # B * T * H

            q_dep_passage_enc_rnn_input = torch.cat(
                [encoded_passage, q_aware_passage_rep], 2)  # B * T * 2H

            # gated question dependent passage encoding
            gated_qaware_passage_rep = self._gate_qdep_penc(
                q_dep_passage_enc_rnn_input)  # B * T * 2H
            encoded_qdep_penc = self._qdep_penc_rnn(gated_qaware_passage_rep,
                                                    passage_mask)  # B * T * H

        # multi factor attentive encoding
        if self._with_knowledge and self._is_mfa_enc:
            if self._is_qdep_penc:
                mfa_enc = self._multifactor_attn(encoded_qdep_penc,
                                                 passage_mask)  # B * T * 2H
            else:
                mfa_enc = self._multifactor_attn(encoded_passage,
                                                 passage_mask)  # B * T * 2H
            encoded_passage = self._mfarnn(mfa_enc, passage_mask)  # B * T * H

        # B * V * H
        if self._shared_rnn:
            encoded_choice = self._variational_dropout(
                self._pseqlevel_enc(embedded_choice, choice_mask))  # B * V * H
        else:
            encoded_choice = self._variational_dropout(
                self._cseqlevel_enc(embedded_choice, choice_mask))  # B * V * H

        if self._with_knowledge:
            attn_pq, _ = self._pqaattnmat(encoded_passage, encoded_allqa,
                                          all_qa_mask)  # B * T * U
            combined_pqa_mask = passage_mask.unsqueeze(-1) * \
                                all_qa_mask.unsqueeze(1)  # B * T * U
            max_attn_pqa = masked_max(attn_pq, combined_pqa_mask,
                                      dim=1)  # B * U
            norm_attn_pqa = masked_softmax(max_attn_pqa, all_qa_mask,
                                           dim=-1)  # B * U
            agg_prev_qa = norm_attn_pqa.unsqueeze(1).bmm(
                encoded_allqa).squeeze(1)  # B * H

            attn_pc, _ = self._pcattnmat(encoded_passage, encoded_choice,
                                         choice_mask)  # B * T * V
            combined_pc_mask = passage_mask.unsqueeze(-1) * \
                               choice_mask.unsqueeze(1)  # B * T * V
            max_attn_pc = masked_max(attn_pc, combined_pc_mask, dim=1)  # B * V
            norm_attn_pc = masked_softmax(max_attn_pc, choice_mask,
                                          dim=-1)  # B * V
            agg_c = norm_attn_pc.unsqueeze(1).bmm(encoded_choice)  # B * 1 * H

            choice_scores_wk = agg_c.bmm(agg_prev_qa.unsqueeze(-1)).squeeze(
                -1)  # B * 1

        if self._qac_ap:
            attn_qac, _ = self._cqaattnmat(encoded_allqa, encoded_choice,
                                           choice_mask)  # B * U * V
            combined_qac_mask = all_qa_mask.unsqueeze(-1) * \
                                choice_mask.unsqueeze(1)  # B * U * V

            max_attn_c = masked_max(attn_qac, combined_qac_mask,
                                    dim=1)  # B * V
            max_attn_qa = masked_max(attn_qac, combined_qac_mask,
                                     dim=2)  # B * U
            norm_attn_c = masked_softmax(max_attn_c, choice_mask,
                                         dim=-1)  # B * V
            norm_attn_qa = masked_softmax(max_attn_qa, all_qa_mask,
                                          dim=-1)  # B * U
            agg_c_qa = norm_attn_c.unsqueeze(1).bmm(encoded_choice).squeeze(
                1)  # B * H
            agg_qa_c = norm_attn_qa.unsqueeze(1).bmm(encoded_allqa).squeeze(
                1)  # B * H

            choice_scores_nk = agg_c_qa.unsqueeze(1).bmm(
                agg_qa_c.unsqueeze(-1)).squeeze(-1)  # B * 1

        if self._with_knowledge and self._qac_ap:
            choice_score = choice_scores_wk + choice_scores_nk
        elif self._qac_ap:
            choice_score = choice_scores_nk
        elif self._with_knowledge:
            choice_score = choice_scores_wk
        else:
            raise NotImplementedError

        output = torch.sigmoid(choice_score).squeeze(-1)  # B

        output_dict = {
            "label_logits": choice_score.squeeze(-1),
            "label_probs": output,
            "metadata": metadata
        }

        if label is not None:
            label = label.long().view(-1)
            loss = self._loss(output, label.float())
            self._auc(output, label)
            output_dict["loss"] = loss

        return output_dict
Example #9
    def forward(
            self,  # type: ignore
            trigger_spans: torch.IntTensor,
            trigger_mask: torch.IntTensor,
            trigger_embeddings: torch.Tensor,
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.Tensor,  # TODO: make sure types are init correctly
            text_mask: torch.IntTensor,
            text_embeddings: torch.Tensor,
            sentence_lengths: torch.IntTensor,
            trigger_labels: torch.IntTensor,
            argument_labels: torch.IntTensor,
            ner_labels: torch.IntTensor,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        The trigger embeddings are just the contextualized token embeddings, and the trigger mask is
        the text mask. For the arguments, we consider all the spans.
        """
        self._active_dataset = metadata.dataset
        self._active_namespaces = {
            "trigger": f"{self._active_dataset}__trigger_labels",
            "argument": f"{self._active_dataset}__argument_labels"
        }

        # Compute trigger scores.
        trigger_scores = self._compute_trigger_scores(trigger_embeddings,
                                                      trigger_mask)

        # Get trigger candidates for event argument labeling.
        num_trigs_to_keep = torch.floor(sentence_lengths.float() *
                                        self._trigger_spans_per_word).long()
        num_trigs_to_keep = torch.max(num_trigs_to_keep,
                                      torch.ones_like(num_trigs_to_keep))
        num_trigs_to_keep = torch.min(num_trigs_to_keep,
                                      15 * torch.ones_like(num_trigs_to_keep))

        trigger_pruner = self._trigger_pruners[
            self._active_namespaces["trigger"]]
        (top_trig_embeddings, top_trig_mask, top_trig_indices, top_trig_scores,
         num_trigs_kept) = trigger_pruner(trigger_embeddings, trigger_mask,
                                          num_trigs_to_keep, trigger_scores)

        top_trig_spans = util.batched_index_select(trigger_spans,
                                                   top_trig_indices)
        top_trig_mask = top_trig_mask.unsqueeze(-1)

        # Compute the number of argument spans to keep.
        num_arg_spans_to_keep = torch.floor(
            sentence_lengths.float() * self._argument_spans_per_word).long()
        num_arg_spans_to_keep = torch.max(
            num_arg_spans_to_keep, torch.ones_like(num_arg_spans_to_keep))
        num_arg_spans_to_keep = torch.min(
            num_arg_spans_to_keep, 30 * torch.ones_like(num_arg_spans_to_keep))

        # If we're using gold event arguments, include the gold labels.
        mention_pruner = self._mention_pruners[
            self._active_namespaces["argument"]]
        gold_labels = None
        (top_arg_embeddings, top_arg_mask, top_arg_indices, top_arg_scores,
         num_arg_spans_kept) = mention_pruner(span_embeddings, span_mask,
                                              num_arg_spans_to_keep,
                                              gold_labels)

        top_arg_mask = top_arg_mask.unsqueeze(-1)
        top_arg_spans = util.batched_index_select(spans, top_arg_indices)

        # Compute trigger / argument pair embeddings.
        trig_arg_embeddings = self._compute_trig_arg_embeddings(
            top_trig_embeddings, top_arg_embeddings, top_trig_spans,
            top_arg_spans, text_mask, text_embeddings)
        argument_scores = self._compute_argument_scores(
            trig_arg_embeddings, top_trig_scores, top_arg_scores, top_arg_mask)

        # Assemble inputs to do prediction.
        output_dict = {
            "top_trigger_spans": top_trig_spans,
            "top_argument_spans": top_arg_spans,
            "trigger_scores": trigger_scores,
            "argument_scores": argument_scores,
            "num_triggers_kept": num_trigs_kept,
            "num_argument_spans_kept": num_arg_spans_kept,
            "trigger_mask": trigger_mask,
            "trigger_spans": trigger_spans,
            "sentence_lengths": sentence_lengths
        }

        prediction_dicts, predictions = self.predict(output_dict, metadata)

        output_dict = {"predictions": predictions}

        # Evaluate loss and F1 if labels were provided.
        if trigger_labels is not None and argument_labels is not None:
            # Compute the loss for both triggers and arguments.
            trigger_loss = self._get_trigger_loss(trigger_scores,
                                                  trigger_labels, trigger_mask)

            gold_arguments = self._get_pruned_gold_arguments(
                argument_labels, top_trig_indices, top_arg_indices,
                top_trig_mask, top_arg_mask)

            argument_loss = self._get_argument_loss(argument_scores,
                                                    gold_arguments)

            # Compute F1.
            assert len(prediction_dicts) == len(
                metadata)  # Make sure length of predictions is right.

            # Compute metrics for this label namespace.
            metrics = self._metrics[self._active_dataset]
            metrics(prediction_dicts, metadata)

            loss = (self._loss_weights["trigger"] * trigger_loss +
                    self._loss_weights["arguments"] * argument_loss)

            output_dict["loss"] = loss

        return output_dict
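
Example #9 computes a per-sentence pruning budget from the sentence lengths, then clamps it between 1 and a fixed maximum (15 for triggers, 30 for argument spans) using torch.max/torch.min against ones_like tensors. The same pattern can be written more compactly with clamp; a minimal sketch with illustrative names:

import torch

def spans_to_keep(sentence_lengths: torch.Tensor, spans_per_word: float, max_keep: int) -> torch.Tensor:
    """Per-sentence pruning budget: floor(length * ratio), clamped to [1, max_keep]."""
    budget = torch.floor(sentence_lengths.float() * spans_per_word).long()
    return budget.clamp(min=1, max=max_keep)


lengths = torch.tensor([5, 40, 300])
print(spans_to_keep(lengths, 0.4, 30))   # tensor([ 2, 16, 30])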
Example #10
def pack_obs(state: Tensor, time: IntTensor) -> Tensor:
    """Reverses the `unpack_obs` transformation."""
    return torch.cat((state, time.float()), dim="R")
Example #11
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            sp_mask: torch.IntTensor = None,
            coref_mask: torch.FloatTensor = None) -> Dict[str, torch.Tensor]:

        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        decoupled_passage, spans_mask = convert_sequence_to_spans(
            embedded_passage, sentence_spans)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()
        encoded_decoupled_passage = \
            self._phrase_layer_sp(
                decoupled_passage, spans_mask.view(batch_size * num_spans, -1))
        context_output_sp = convert_span_to_sequence(
            embedded_passage, encoded_decoupled_passage, spans_mask)

        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        ques_output_sp = self._phrase_layer_sp(embedded_question, ques_mask)

        modeled_passage_sp = self.qc_att_sp(context_output_sp, ques_output_sp,
                                            ques_mask)
        modeled_passage_sp = self.linear_2(modeled_passage_sp)
        modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp,
                                                     context_mask)
        # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(
            modeled_passage_sp, sentence_spans)
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        gate_logit = self._span_gate(spans_rep_sp, spans_mask)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()
        sent_mask = (sent_labels >= 0).long()
        sent_labels = sent_labels * sent_mask
        # print(sent_labels)
        # print(gate_logit.shape)
        # print(gate_logit)
        strong_sup_loss = torch.mean(-torch.log(
            torch.sum(F.softmax(gate_logit, dim=-1) *
                      sent_labels.float().view(batch_size, num_spans),
                      dim=-1) + 1e-10))

        # strong_sup_loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
        #                              sent_labels.long().view(batch_size * num_spans), ignore_index=-1)

        gate = torch.argmax(gate_logit.view(batch_size, num_spans), -1)
        # gate = (gate >= 0.5).long().view(batch_size, num_spans)
        output_dict = {"gate": gate}

        loss = strong_sup_loss

        output_dict["loss"] = loss

        if metadata is not None:
            question_tokens = []
            passage_tokens = []
            sent_labels_list = []
            ids = []

            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])

            self._sent_metrics(gate, sent_labels)
            # print(self.get_prediction(gate, sent_labels).item())
            # print(self.get_prediction(gate, sent_labels).data)
            output_dict['predict'] = [
                self.get_prediction(gate, sent_labels).data
            ]
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['sent_labels'] = sent_labels_list
            output_dict['_id'] = ids
            # print(ids)

        return output_dict
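
Example #11 pads sentence labels with -1 and neutralises that padding before computing its supervision loss: a mask remembers which sentences were real, and the padded labels are zeroed out. A minimal sketch of that masking step:

import torch

# Padded sentences carry label -1; remember which positions were real and zero out the rest.
sent_labels = torch.tensor([[1, 0, -1, -1],
                            [0, 1, 1, -1]])
sent_mask = (sent_labels >= 0).long()    # 1 for real sentences, 0 for padding
sent_labels = sent_labels * sent_mask    # padding labels become 0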
Example #12
    def forward(
        self,  # type: ignore
        source_spans: torch.IntTensor,
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for source sentence representation.
            Comes from a ``ListField[SpanField]`` of indices into the source sentence.
        source_tokens : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be
           passed through a ``TextFieldEmbedder`` and then through an encoder.
        target_tokens : Dict[str, torch.LongTensor], optional (default = None)
           Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the
           target tokens are also represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, encoder_output_dim)
        embedded_input = self._source_embedder(source_tokens)

        num_spans = source_spans.size(1)
        source_length = embedded_input.size(1)
        batch_size, _, _ = embedded_input.size()

        # (batch_size, source_length)
        source_mask = get_text_field_mask(source_tokens)

        # Shape: (batch_size, num_spans)
        span_mask = (source_spans[:, :, 0] >= 0).squeeze(-1).float()

        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(source_spans.float()).long()

        # Contextualized word embeddings; Shape: (batch_size, source_length, embedding_dim)
        contextualized_word_embeddings = self._encoder(embedded_input,
                                                       source_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        span_embeddings = self._span_extractor(contextualized_word_embeddings,
                                               spans)

        # Prune based on feedforward scorer
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * source_length))

        # Shape: see return section of SpanPruner docs
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_scores) = self._span_pruner(span_embeddings, span_mask,
                                              num_spans_to_keep)

        # Shape: (batch_size * num_spans_to_keep)
        flat_top_span_indices = flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = batched_index_select(spans, top_span_indices,
                                         flat_top_span_indices)

        # Here we define what we will init first hidden state of decoder with
        summary_of_encoded_source = contextualized_word_embeddings[:, -1]  # (batch_size, encoder_output_dim)
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end symbol. Either way, we
            # don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Condition decoder on encoder
        # Here we just derive and append one more dummy embedding feature (sum) to match dimensions later
        # Shape: (batch_size, encoder_output_dim + 1)
        decoder_hidden = torch.cat(
            (summary_of_encoded_source,
             summary_of_encoded_source.sum(1).unsqueeze(1)), 1)
        decoder_context = Variable(top_span_embeddings.data.new().resize_(
            batch_size, self._decoder_output_dim).fill_(0))
        last_predictions = None
        step_logits = []
        step_probabilities = []
        step_predictions = []
        step_attention_weights = []
        for timestep in range(num_decoding_steps):
            if self.training and all(
                    torch.rand(1) >= self._scheduled_sampling_ratio):
                input_choices = targets[:, timestep]
            else:
                if timestep == 0:
                    # For the first timestep, when we do not have targets, we input start symbols.
                    # (batch_size,)
                    input_choices = Variable(
                        source_mask.data.new().resize_(batch_size).fill_(
                            self._start_index))
                else:
                    input_choices = last_predictions

            # We append span scores to the span embedding features to make SpanPrune trainable
            # Shape: (batch_size, num_spans_to_keep, span_embedding_dim + 1)
            top_span_embeddings_scores = torch.cat(
                (top_span_embeddings, top_span_scores), 2)
            # Shape: (batch_size, decoder_input_dim)
            decoder_input, attention_weights = self._prepare_decode_step_input(
                input_choices, decoder_hidden, top_span_embeddings_scores,
                top_span_mask)

            if attention_weights is not None:
                step_attention_weights.append(attention_weights)

            # Shape: both (batch_size, decoder_output_dim),
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))
            # (batch_size, num_classes)
            output_projections = self._output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            step_probabilities.append(class_probabilities.unsqueeze(1))
            last_predictions = predicted_classes
            # (batch_size, 1)
            step_predictions.append(last_predictions.unsqueeze(1))
        # step_logits is a list containing tensors of shape (batch_size, 1, num_classes)
        # This is (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        class_probabilities = torch.cat(step_probabilities, 1)
        all_predictions = torch.cat(step_predictions, 1)

        # step_attention_weights is a list containing tensors of shape (batch_size, num_encoder_outputs)
        # This is (batch_size, num_decoding_steps, num_encoder_outputs)
        if len(step_attention_weights) > 0:
            attention_matrix = torch.cat(step_attention_weights, 0)
            attention_matrix.unsqueeze_(0)
        else:
            attention_matrix = None
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities,
            "predictions": all_predictions,
            "top_spans": top_spans,
            "attention_matrix": attention_matrix,
            "top_spans_scores": top_span_scores
        }
        if target_tokens:
            target_mask = get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss  # + top_span_scores.squeeze().view(-1).index_select(0, top_span_mask.view(-1).long()).sum()
        return output_dict
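
Example #12's decoding loop uses scheduled sampling: at each timestep during training, the gold target token is fed with probability 1 - scheduled_sampling_ratio, otherwise the model's own previous prediction is reused (with a start symbol at the first step). A minimal sketch of that choice; names are illustrative and it assumes last_predictions already holds the start symbol at timestep 0.

import torch

def choose_decoder_input(targets: torch.Tensor, last_predictions: torch.Tensor,
                         timestep: int, scheduled_sampling_ratio: float,
                         training: bool) -> torch.Tensor:
    """Scheduled sampling: feed gold tokens with probability 1 - ratio during training."""
    if training and torch.rand(1).item() >= scheduled_sampling_ratio:
        return targets[:, timestep]   # teacher forcing with the gold token
    return last_predictions           # reuse the model's previous prediction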