def _arithmetic_log_likelihood(self,
                                 answer_as_expressions,
                                 arithmetic_template_slot_log_probs, 
                                 arithmetic_template_log_probs):
     # answer_as_expressions : (batch, #templates, #expressions, #slots)
     # arithmetic_template_slot_log_probs : (batch, #templates, #slots, #numbers)
     # arithmetic_template_log_probs : (batch, #templates)
     
     # shape : (batch, #templates, #slots, #expressions)
     gold_templates = answer_as_expressions.transpose(2,3).long()
     
     # mask for invalid/padded expressions
     gold_templates_mask = (gold_templates[:,:,:,:] != -1).long()
     clamped_gold_templates = \
         util.replace_masked_values(gold_templates, gold_templates_mask, 0)
     
     # shape : (batch, #templates, #slots, #expressions)
     log_likelihood_per_slot = \
         torch.gather(arithmetic_template_slot_log_probs, -1, clamped_gold_templates)
     
     # shape : (batch, #templates, #expressions)
     log_likelihood_per_expression = log_likelihood_per_slot.sum(2)
     # mask out padded expressions
     log_likelihood_per_expression = util.replace_masked_values(log_likelihood_per_expression, 
                                                                gold_templates_mask[:,:,0,:], 
                                                                -1e7)
     # shape : (batch, #templates)
     log_likelihood_per_template = util.logsumexp(log_likelihood_per_expression)
     log_joint_likelihood_for_arithmetic = log_likelihood_per_template + arithmetic_template_log_probs
     
     # Shape: (batch_size, )
     log_marginal_likelihood_for_arithmetic = util.logsumexp(log_joint_likelihood_for_arithmetic)
     return log_marginal_likelihood_for_arithmetic
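As a sanity check, the gather-then-logsumexp marginalization above can be reproduced with plain PyTorch on toy tensors; the shapes and values below are invented for illustration and are not part of the original model.

import torch

# Minimal sketch: 1 batch, 2 templates, 2 slots, 3 candidate numbers, 2 expressions per template.
slot_log_probs = torch.log_softmax(torch.randn(1, 2, 2, 3), dim=-1)    # (batch, #templates, #slots, #numbers)
gold = torch.tensor([[[[0, 2], [-1, -1]],       # template 0: one gold expression, one padded with -1
                      [[2, 0], [1, 1]]]])       # template 1: two gold expressions
# (batch, #templates, #expressions, #slots) -> (batch, #templates, #slots, #expressions)
gold = gold.transpose(2, 3)
mask = (gold != -1).long()
per_slot = torch.gather(slot_log_probs, -1, gold.clamp(min=0))
per_expression = per_slot.sum(2).masked_fill(mask[:, :, 0, :] == 0, -1e7)
per_template = torch.logsumexp(per_expression, dim=-1)                  # (batch, #templates)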
    def _base_arithmetic_module(self, passage_vector, passage_out, number_indices, number_mask):
        number_indices = number_indices[:,:,0].long()
        clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
        encoded_numbers = torch.gather(
                passage_out,
                1,
                clamped_number_indices.unsqueeze(-1).expand(-1, -1, passage_out.size(-1)))
            
        if self.num_special_numbers > 0:
            special_numbers = self.special_embedding(torch.arange(self.num_special_numbers, device=number_indices.device))
            special_numbers = special_numbers.expand(number_indices.shape[0],-1,-1)
            encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)
            
            mask = torch.ones((number_indices.shape[0],self.num_special_numbers), device=number_indices.device).long()
            number_mask = torch.cat([mask, number_mask], -1)
        # Shape: (batch_size, # of numbers, 2*bert_dim)
        encoded_numbers = torch.cat(
                [encoded_numbers, passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)

        # Shape: (batch_size, # of numbers in the passage, 3)
        number_sign_logits = self._number_sign_predictor(encoded_numbers)
        number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)

        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
        # For padded numbers, the best sign is masked to 0 (not included).
        best_signs_for_numbers = util.replace_masked_values(best_signs_for_numbers, number_mask, 0)
        return number_sign_log_probs, best_signs_for_numbers, number_mask
Example #3
 def _question_span_log_likelihood(self, answer_as_question_spans,
                                   question_span_start_log_probs,
                                   question_span_end_log_probs):
     # Shape: (batch_size, # of answer spans)
     gold_question_span_starts = answer_as_question_spans[:, :, 0]
     gold_question_span_ends = answer_as_question_spans[:, :, 1]
     # Some spans are padded with index -1,
     # so we clamp those paddings to 0 and then mask after `torch.gather()`.
     gold_question_span_mask = (gold_question_span_starts != -1).long()
     clamped_gold_question_span_starts = \
         util.replace_masked_values(gold_question_span_starts, gold_question_span_mask, 0)
     clamped_gold_question_span_ends = \
         util.replace_masked_values(gold_question_span_ends, gold_question_span_mask, 0)
     # Shape: (batch_size, # of answer spans)
     log_likelihood_for_question_span_starts = \
         torch.gather(question_span_start_log_probs, 1, clamped_gold_question_span_starts)
     log_likelihood_for_question_span_ends = \
         torch.gather(question_span_end_log_probs, 1, clamped_gold_question_span_ends)
     # Shape: (batch_size, # of answer spans)
     log_likelihood_for_question_spans = \
         log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends
     # For those padded spans, we set their log probabilities to be very small negative value
     log_likelihood_for_question_spans = \
         util.replace_masked_values(log_likelihood_for_question_spans,
                                    gold_question_span_mask,
                                    -1e7)
     # Shape: (batch_size, )
     log_marginal_likelihood_for_question_span = \
         util.logsumexp(log_likelihood_for_question_spans)
     return log_marginal_likelihood_for_question_span
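The same clamp/gather/mask/logsumexp pattern on a toy example, in plain PyTorch (shapes invented for illustration):

import torch

start_log_probs = torch.log_softmax(torch.randn(1, 6), dim=-1)    # (batch, question_length)
end_log_probs = torch.log_softmax(torch.randn(1, 6), dim=-1)
answer_spans = torch.tensor([[[0, 2], [3, 5], [-1, -1]]])          # (batch, #spans, 2); last span is padding
span_mask = (answer_spans[:, :, 0] != -1).long()
starts = answer_spans[:, :, 0].clamp(min=0)
ends = answer_spans[:, :, 1].clamp(min=0)
span_log_likelihood = torch.gather(start_log_probs, 1, starts) + torch.gather(end_log_probs, 1, ends)
span_log_likelihood = span_log_likelihood.masked_fill(span_mask == 0, -1e7)
marginal_log_likelihood = torch.logsumexp(span_log_likelihood, dim=-1)  # (batch,)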
Example #4
    def forward_span(self,
                     ds_name,
                     dialog_repr,
                     repeated_ds_embeddings,
                     context_masks,
                     span_labels=None,
                     spans_start=None,
                     spans_end=None):
        batch_size, max_dialog_len = context_masks.size()
        ds_dialog_sim = self._ds_dialog_attention(
            self._dropout(repeated_ds_embeddings), self._dropout(dialog_repr))
        ds_dialog_att = util.masked_softmax(
            ds_dialog_sim.view(-1, max_dialog_len),
            context_masks.view(-1, max_dialog_len))
        ds_dialog_att = ds_dialog_att.view(batch_size, max_dialog_len)
        ds_dialog_repr = util.weighted_sum(dialog_repr, ds_dialog_att)
        ds_dialog_repr = ds_dialog_repr + repeated_ds_embeddings.squeeze(1)
        span_label_logits = self._span_label_predictor(
            F.relu(self._dropout(ds_dialog_repr)))
        span_label_prediction = torch.argmax(span_label_logits, dim=1)
        span_label_loss = 0.0
        if span_labels is not None:
            span_label_loss = self._cross_entropy(
                span_label_logits, span_labels)  # loss averaged by #turn
            self._accuracy.span_label_acc(ds_name, span_label_logits,
                                          span_labels, span_labels != -1)
        loss = span_label_loss

        w = self._span_prediction_layer(
            self._dropout(ds_dialog_repr)).unsqueeze(1)
        span_start_repr = self._span_start_encoder(self._dropout(dialog_repr))
        span_start_logits = torch.bmm(w,
                                      span_start_repr.transpose(1,
                                                                2)).squeeze(1)
        span_start_probs = util.masked_softmax(span_start_logits,
                                               context_masks)
        span_start_logits = util.replace_masked_values(
            span_start_logits, context_masks.to(dtype=torch.int8), -1e7)

        span_end_repr = self._span_end_encoder(self._dropout(span_start_repr))
        span_end_logits = torch.bmm(w, span_end_repr.transpose(1,
                                                               2)).squeeze(1)
        span_end_probs = util.masked_softmax(span_end_logits, context_masks)
        span_end_logits = util.replace_masked_values(
            span_end_logits, context_masks.to(dtype=torch.int8), -1e7)

        best_span = self.get_best_span(span_start_logits, span_end_logits)
        best_span = best_span.view(batch_size, -1)

        spans_loss = 0.0
        if spans_start is not None:
            spans_loss = self._cross_entropy(span_start_logits, spans_start)
            self._accuracy.span_start_acc(ds_name, span_start_logits,
                                          spans_start, spans_start != -1)
            spans_loss += self._cross_entropy(span_end_logits, spans_end)
            self._accuracy.span_end_acc(ds_name, span_end_logits, spans_end,
                                        spans_end != -1)
        loss += spans_loss

        return loss, (span_label_prediction, best_span)
Example #5
 def _base_arithmetic_log_likelihood(self, answer_as_expressions,
                                     number_sign_log_probs, number_mask,
                                     answer_as_expressions_extra):
     if self.num_special_numbers > 0:
         answer_as_expressions = torch.cat(
             [answer_as_expressions_extra, answer_as_expressions], -1)
     # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
     # Shape: (batch_size, # of combinations)
     gold_add_sub_mask = (answer_as_expressions.sum(-1) > 0).float()
     # Shape: (batch_size, # of numbers in the passage, # of combinations)
     gold_add_sub_signs = answer_as_expressions.transpose(1, 2)
     # Shape: (batch_size, # of numbers in the passage, # of combinations)
     log_likelihood_for_number_signs = torch.gather(number_sign_log_probs,
                                                    2, gold_add_sub_signs)
     # the log likelihood of the masked positions should be 0
     # so that it will not affect the joint probability
     log_likelihood_for_number_signs = \
         util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
     # Shape: (batch_size, # of combinations)
     log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(1)
     # For those padded combinations, we set their log probabilities to be very small negative value
     log_likelihood_for_add_subs = \
         util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
     # Shape: (batch_size,)
     log_marginal_likelihood_for_add_sub = util.logsumexp(
         log_likelihood_for_add_subs)
     return log_marginal_likelihood_for_add_sub
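A plain-PyTorch sketch of the same sign marginalization on toy tensors (the shapes, signs, and padding below are invented for illustration):

import torch

# batch=1, 3 numbers (the third is padding), 2 combinations (the second is all-zero padding),
# signs in {0: ignore, 1: plus, 2: minus}.
sign_log_probs = torch.log_softmax(torch.randn(1, 3, 3), dim=-1)   # (batch, #numbers, 3)
expressions = torch.tensor([[[1, 2, 0],
                             [0, 0, 0]]])                           # (batch, #combinations, #numbers)
number_mask = torch.tensor([[1, 1, 0]])
comb_mask = (expressions.sum(-1) > 0).long()                        # (batch, #combinations)
signs = expressions.transpose(1, 2)                                 # (batch, #numbers, #combinations)
per_number = torch.gather(sign_log_probs, 2, signs)
per_number = per_number.masked_fill(number_mask.unsqueeze(-1) == 0, 0.0)
per_combination = per_number.sum(1).masked_fill(comb_mask == 0, -1e7)
marginal = torch.logsumexp(per_combination, dim=-1)                  # (batch,)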
Example #6
    def _question_span_module(self, passage_vector, question_out,
                              question_mask):
        # Shape: (batch_size, question_length, 2*bert_dim)
        encoded_question_for_span_prediction = \
            torch.cat([question_out,
                       passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)], -1)
        question_span_start_logits = \
            self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
        # Shape: (batch_size, question_length)
        question_span_end_logits = \
            self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
        question_span_start_log_probs = util.masked_log_softmax(
            question_span_start_logits, question_mask)
        question_span_end_log_probs = util.masked_log_softmax(
            question_span_end_logits, question_mask)

        # Info about the best question span prediction
        question_span_start_logits = \
            util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
        question_span_end_logits = \
            util.replace_masked_values(question_span_end_logits, question_mask, -1e7)

        # Shape: (batch_size, 2)
        best_question_span = get_best_span(question_span_start_logits,
                                           question_span_end_logits)
        return question_span_start_log_probs, question_span_end_log_probs, best_question_span
Example #7
    def _passage_span_module(self, passage_out, passage_mask):
        # Shape: (batch_size, passage_length)
        passage_span_start_logits = self._passage_span_start_predictor(
            passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_end_logits = self._passage_span_end_predictor(
            passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(
            passage_span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(
            passage_span_end_logits, passage_mask)

        # Info about the best passage span prediction
        passage_span_start_logits = util.replace_masked_values(
            passage_span_start_logits, passage_mask, -1e7)
        passage_span_end_logits = util.replace_masked_values(
            passage_span_end_logits, passage_mask, -1e7)

        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(passage_span_start_logits,
                                          passage_span_end_logits)
        return passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span
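`get_best_span` is referenced but not defined in these snippets; a common way to implement it (sketched here as an assumption, not necessarily the exact function used above) is an argmax over the start+end score matrix restricted to end >= start:

import torch

def get_best_span(span_start_logits: torch.Tensor,
                  span_end_logits: torch.Tensor) -> torch.Tensor:
    # Score every (start, end) pair; scores[b, i, j] = start_i + end_j.
    batch_size, length = span_start_logits.shape
    scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Only pairs with end >= start are valid spans.
    valid = torch.ones(length, length, device=scores.device).triu().bool()
    scores = scores.masked_fill(~valid, float("-inf"))
    best = scores.view(batch_size, -1).argmax(-1)
    starts = torch.div(best, length, rounding_mode="floor")
    return torch.stack([starts, best % length], dim=-1)   # (batch, 2)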
Example #8
    def _get_edge_probabilities(self, encoded_premise,
                                mean_node_premise_attention, edge_sources,
                                edge_targets, edge_labels,
                                metadata) -> FloatTensor:
        # dim: batch x nodes x emb. dim
        aggregate_node_premise_lstm_representation = weighted_sum(
            encoded_premise, mean_node_premise_attention)
        # dim: batch x edges x 1
        edge_mask = (edge_sources != -1).float()
        edge_source_lstm_repr = self._select_embeddings_using_index(
            aggregate_node_premise_lstm_representation,
            replace_masked_values(edge_sources.float(), edge_mask, 0))
        edge_target_lstm_repr = self._select_embeddings_using_index(
            aggregate_node_premise_lstm_representation,
            replace_masked_values(edge_targets.float(), edge_mask, 0))
        # edge label embeddings. dim: batch x edges x edge dim
        masked_edge_labels = replace_masked_values(edge_labels.float(),
                                                   edge_mask,
                                                   0).squeeze(2).long()
        edge_label_embeddings = self._edge_embedding(masked_edge_labels)
        # dim: batch x edges x (2* emb dim + edge dim)

        combined_edge_representation = torch.cat([
            edge_source_lstm_repr, edge_label_embeddings, edge_target_lstm_repr
        ], 2)
        edge_prob_distribution = self._edge_probability(
            combined_edge_representation)
        edges_only_mask = edge_mask.expand_as(edge_prob_distribution).float()
        mean_edge_distribution = masked_mean(edge_prob_distribution, 1,
                                             edges_only_mask)
        return mean_edge_distribution
Example #9
    def forward(self,
                **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]:

        input, mask = self.get_input_and_mask(kwargs)

        # Shape: (batch_size, passage_length)
        start_logits = self._start_output_layer(input).squeeze(-1)

        # Shape: (batch_size, passage_length)
        end_logits = self._end_output_layer(input).squeeze(-1)

        start_log_probs = masked_log_softmax(start_logits, mask)
        end_log_probs = masked_log_softmax(end_logits, mask)

        # Info about the best span prediction
        start_logits = replace_masked_values(start_logits, mask, -1e7)
        end_logits = replace_masked_values(end_logits, mask, -1e7)

        # Shape: (batch_size, 2)
        best_span = get_best_span(start_logits, end_logits)

        output_dict = {
            'start_log_probs': start_log_probs,
            'end_log_probs': end_log_probs,
            'best_span': best_span
        }
        return output_dict
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        sentence_mask = util.get_text_field_mask(sentence).float()
        embedded_sentence = self._text_field_embedder(sentence)

        dropped_embedded_sent = self._embedding_dropout(embedded_sentence)
        pre_encoded_sent = self._pre_encode_feedforward(dropped_embedded_sent)
        encoded_tokens = self._encoder(pre_encoded_sent, sentence_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(
            encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.last_dim_softmax(attention_logits,
                                                  sentence_mask)
        encoded_sentence = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([
            encoded_tokens, encoded_tokens - encoded_sentence,
            encoded_tokens * encoded_sentence
        ], 2)
        integrated_encodings = self._integrator(integrator_input,
                                                sentence_mask)

        # Simple Pooling layers
        max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, sentence_mask.unsqueeze(2), -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, sentence_mask.unsqueeze(2), +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(
            sentence_mask, 1, keepdim=True)

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits,
                                           sentence_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings,
                                                self_weights)

        pooled_representations = torch.cat(
            [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(
            pooled_representations).squeeze(1)

        logits = self._output_layer(pooled_representations_dropped)
        output_dict = {'logits': logits}
        if label is not None:
            loss = self.loss(logits, label.squeeze(-1))
            for metric in self.metrics.values():
                metric(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
    def compute_location_spans(self, contextual_seq_embedding, embedded_sentence_verb_entity, mask):
        # Layer 5: Span prediction for before and after location
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        batch_size, num_sentences, num_participants, sentence_length, encoder_dim = contextual_seq_embedding.shape
        #print("contextual_seq_embedding: ", contextual_seq_embedding.shape)
        # size(span_start_input_after): batch_size * num_sentences *
        #                                num_participants * sentence_length * (embedding_size+2+2*seq2seq_output_size)
        span_start_input_after = torch.cat([embedded_sentence_verb_entity, contextual_seq_embedding], dim=-1)

        #print("span_start_input_after: ", span_start_input_after.shape)
        # Shape: (bs, ns , np, sl)
        span_start_logits_after = self._span_start_predictor_after(span_start_input_after).squeeze(-1)
        #print("span_start_logits_after: ", span_start_logits_after.shape)

        # Shape: (bs, ns , np, sl)
        span_start_probs_after = util.masked_softmax(span_start_logits_after, mask)
        #print("span_start_probs_after: ", span_start_probs_after.shape)

        # span_start_representation_after: (bs, ns , np, encoder_dim)
        span_start_representation_after = util.weighted_sum(contextual_seq_embedding, span_start_probs_after)
        #print("span_start_representation_after: ", span_start_representation_after.shape)

        # span_tiled_start_representation_after: (bs, ns , np, sl, 2*seq2seq_output_size)
        span_tiled_start_representation_after = span_start_representation_after.unsqueeze(3).expand(batch_size,
                                                                                                    num_sentences,
                                                                                                    num_participants,
                                                                                                    sentence_length,
                                                                                                    encoder_dim)
        #print("span_tiled_start_representation_after: ", span_tiled_start_representation_after.shape)

        # Shape: (batch_size, passage_length, (embedding+2  + encoder_dim + encoder_dim + encoder_dim))
        span_end_representation_after = torch.cat([embedded_sentence_verb_entity,
                                                   contextual_seq_embedding,
                                                   span_tiled_start_representation_after,
                                                   contextual_seq_embedding * span_tiled_start_representation_after],
                                                  dim=-1)
        #print("span_end_representation_after: ", span_end_representation_after.shape)

        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end_after = self.time_distributed_encoder_span_end_after(span_end_representation_after, mask)
        #print("encoded_span_end_after: ", encoded_span_end_after.shape)

        span_end_logits_after = self._span_end_predictor_after(encoded_span_end_after).squeeze(-1)
        #print("span_end_logits_after: ", span_end_logits_after.shape)

        span_end_probs_after = util.masked_softmax(span_end_logits_after, mask)
        #print("span_end_probs_after: ", span_end_probs_after.shape)

        span_start_logits_after = util.replace_masked_values(span_start_logits_after, mask, -1e7)
        span_end_logits_after = util.replace_masked_values(span_end_logits_after, mask, -1e7)

        # Fixme: we should condition this on predicted_action so that we can output '-' when needed
        # Fixme: also add a functionality to be able to output '?': we can use span_start_probs_after, span_end_probs_after
        best_span_after = self.get_best_span(span_start_logits_after, span_end_logits_after)
        #print("best_span_after: ", best_span_after)
        return best_span_after, span_start_logits_after, span_end_logits_after
Example #12
    @staticmethod
    def _get_span_answer_log_prob(
            answer_as_spans: torch.LongTensor,
            span_log_probs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """ Compute the log_marginal_likelihood for the answer_spans given log_probs for start/end
            Compute log_likelihood (product of start/end probs) of each ans_span
            Sum the prob (logsumexp) for each span and return the log_likelihood

        Parameters:
        -----------
        answer: ``torch.LongTensor`` Shape: (number_of_spans, 2)
            These are the gold spans
        span_log_probs: ``torch.FloatTensor``
            2-Tuple with tensors of Shape: (length_of_sequence) for span_start/span_end log_probs

        Returns:
        log_marginal_likelihood_for_passage_span
        """
        # Unsqueezing dim=0 to make a batch_size of 1
        answer_as_spans = answer_as_spans.unsqueeze(0)

        span_start_log_probs, span_end_log_probs = span_log_probs

        span_start_log_probs = span_start_log_probs.unsqueeze(0)
        span_end_log_probs = span_end_log_probs.unsqueeze(0)

        # (batch_size, number_of_ans_spans)
        gold_passage_span_starts = answer_as_spans[:, :, 0]
        gold_passage_span_ends = answer_as_spans[:, :, 1]
        # Some spans are padded with index -1,
        # so we clamp those paddings to 0 and then mask after `torch.gather()`.
        gold_passage_span_mask = (gold_passage_span_starts != -1).long()
        clamped_gold_passage_span_starts = allenutil.replace_masked_values(
            gold_passage_span_starts, gold_passage_span_mask, 0)
        clamped_gold_passage_span_ends = allenutil.replace_masked_values(
            gold_passage_span_ends, gold_passage_span_mask, 0)
        # Shape: (batch_size, # of answer spans)
        log_likelihood_for_span_starts = torch.gather(
            span_start_log_probs, 1, clamped_gold_passage_span_starts)
        log_likelihood_for_span_ends = torch.gather(
            span_end_log_probs, 1, clamped_gold_passage_span_ends)

        # Shape: (batch_size, # of answer spans)
        log_likelihood_for_spans = log_likelihood_for_span_starts + log_likelihood_for_span_ends
        # For those padded spans, we set their log probabilities to be very small negative value
        log_likelihood_for_spans = allenutil.replace_masked_values(
            log_likelihood_for_spans, gold_passage_span_mask, -1e7)

        # Shape: (batch_size, )
        log_marginal_likelihood_for_span = allenutil.logsumexp(
            log_likelihood_for_spans)

        return log_marginal_likelihood_for_span
Example #13
 def encode(self, sentence, mask):
     out1, (ht1, ct1) = self.rnn1(sentence)
     # max pool
     emb1, _ = replace_masked_values(out1, mask.unsqueeze(-1),
                                     -1e7).max(dim=1)
     out2, (ht2, ct2) = self.rnn2(sentence, (ht1, ct1))
     # max pool
     emb2, _ = replace_masked_values(out2, mask.unsqueeze(-1),
                                     -1e7).max(dim=1)
     out3, (ht3, ct3) = self.rnn3(sentence, (ht2, ct2))
     emb3, _ = replace_masked_values(out3, mask.unsqueeze(-1),
                                     -1e7).max(dim=1)
     return torch.cat([emb1, emb2, emb3], dim=-1)
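The `replace_masked_values(...).max(dim=1)` idiom above is just masked max-pooling; a plain-PyTorch toy illustration (values invented):

import torch

out = torch.randn(2, 5, 8)                       # (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])           # (batch, seq_len)
# Padded positions are pushed to -1e7 so they never win the max.
pooled, _ = out.masked_fill(mask.unsqueeze(-1) == 0, -1e7).max(dim=1)   # (batch, hidden)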
    def _arithmetic_module(self, arithmetic_passage_vector, passage_out, number_indices, number_mask):
        if self.number_rep in ['average', 'attention']:
            
            # Shape: (batch_size, # of numbers, # of pieces) 
            number_indices = util.replace_masked_values(number_indices, number_indices != -1, 0).long()
            batch_size = number_indices.shape[0]
            num_numbers = number_indices.shape[1]
            seqlen = passage_out.shape[1]

            # Shape : (batch_size, # of numbers, seqlen)
            mask = torch.zeros((batch_size, num_numbers, seqlen), device=number_indices.device).long().scatter(
                        2, 
                        number_indices, 
                        torch.ones(number_indices.shape, device=number_indices.device).long())
            mask[:,:,0] = 0

            # Shape : (batch_size, # of numbers, seqlen, bert_dim)
            epassage_out = passage_out.unsqueeze(1).repeat(1,num_numbers,1,1)

            # Shape : (batch_size, # of numbers, bert_dim)
            encoded_numbers = self.summary_vector(epassage_out, mask, "numbers")
        else:
            number_indices = number_indices[:,:,0].long()
            clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                    passage_out,
                    1,
                    clamped_number_indices.unsqueeze(-1).expand(-1, -1, passage_out.size(-1)))
            
        if self.num_special_numbers > 0:
            special_numbers = self.special_embedding(torch.arange(self.num_special_numbers, device=number_indices.device))
            special_numbers = special_numbers.expand(number_indices.shape[0],-1,-1)
            encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)
            
            mask = torch.ones((number_indices.shape[0],self.num_special_numbers), device=number_indices.device).long()
            number_mask = torch.cat([mask, number_mask], -1)
            
        # Shape: (batch_size, # of numbers, 2*bert_dim)
        encoded_numbers = torch.cat(
                [encoded_numbers, arithmetic_passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)
        
        # Shape: (batch_size, #templates, #slots, #numbers)
        arithmetic_template_slot_logits = self._arithmetic_template_slot_predictor(encoded_numbers).transpose(1,2)
        arithmetic_template_slot_log_probs = util.masked_log_softmax(arithmetic_template_slot_logits, number_mask)
        arithmetic_template_slot_log_probs = arithmetic_template_slot_log_probs.reshape(number_mask.shape[0],
                                                                                        self.num_arithmetic_templates, 
                                                                                        self.num_template_slots, 
                                                                                        number_mask.shape[-1])
        # Shape: (batch_size, #templates, #slots)
        arithmetic_best_template_slots = arithmetic_template_slot_log_probs.argmax(-1)
        return arithmetic_template_slot_log_probs, arithmetic_best_template_slots, number_mask
Example #15
    def _base_arithmetic_module(self, passage_vector, passage_out, number_indices, number_mask):
        if self.number_rep in ['average', 'attention']:

            # Shape: (batch_size, # of numbers, # of pieces)
            number_indices = util.replace_masked_values(number_indices, number_indices != -1, 0).long()
            batch_size = number_indices.shape[0]
            num_numbers = number_indices.shape[1]
            seqlen = passage_out.shape[1]

            # Shape : (batch_size, # of numbers, seqlen)
            mask = torch.zeros((batch_size, num_numbers, seqlen), device=number_indices.device).long().scatter(
                2,
                number_indices,
                torch.ones(number_indices.shape, device=number_indices.device).long())
            mask[:, :, 0] = 0

            # Shape : (batch_size, # of numbers, seqlen, bert_dim)
            epassage_out = passage_out.unsqueeze(1).repeat(1, num_numbers, 1, 1)

            # Shape : (batch_size, # of numbers, bert_dim)
            encoded_numbers = self.summary_vector(epassage_out, mask, "numbers")
        else:
            number_indices = number_indices[:, :, 0].long()
            clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                passage_out,
                1,
                clamped_number_indices.unsqueeze(-1).expand(-1, -1, passage_out.size(-1)))

        if self.num_special_numbers > 0:
            special_numbers = self.special_embedding(
                torch.arange(self.num_special_numbers, device=number_indices.device))
            special_numbers = special_numbers.expand(number_indices.shape[0], -1, -1)
            encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)

            mask = torch.ones((number_indices.shape[0], self.num_special_numbers), device=number_indices.device).long()
            number_mask = torch.cat([mask, number_mask], -1)
        # Shape: (batch_size, # of numbers, 2*bert_dim)
        encoded_numbers = torch.cat(
            [encoded_numbers, passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)

        # Shape: (batch_size, # of numbers in the passage, 3)
        number_sign_logits = self._number_sign_predictor(encoded_numbers)
        number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)

        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
        # For padded numbers, the best sign is masked to 0 (not included).
        best_signs_for_numbers = util.replace_masked_values(best_signs_for_numbers, number_mask, 0)
        return number_sign_log_probs, best_signs_for_numbers, number_mask
 def _count_log_likelihood(self, answer_as_counts, count_number_log_probs):
     # Count answers are padded with label -1,
     # so we clamp those paddings to 0 and then mask after `torch.gather()`.
     # Shape: (batch_size, # of count answers)
     gold_count_mask = (answer_as_counts != -1).long()
     # Shape: (batch_size, # of count answers)
     clamped_gold_counts = util.replace_masked_values(answer_as_counts, gold_count_mask, 0)
     log_likelihood_for_counts = torch.gather(count_number_log_probs, 1, clamped_gold_counts)
     # For those padded spans, we set their log probabilities to be very small negative value
     log_likelihood_for_counts = \
         util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7)
     # Shape: (batch_size, )
     log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
     return log_marginal_likelihood_for_count
    def module(self, bert_out, seq_mask=None):
        logits = self.predictor(bert_out)

        if self._use_crf:
            # The mask should not be applied here when using a CRF; it should instead be passed to the CRF
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        else:
            if seq_mask is not None:
                log_probs = replace_masked_values(torch.nn.functional.log_softmax(logits, dim=-1), seq_mask.unsqueeze(-1), 0.0)
                logits = replace_masked_values(logits, seq_mask.unsqueeze(-1), -1e7)
            else:
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        return log_probs, logits
Example #18
    def _count_loss(self, answer_as_counts, count_number, max_prob, min_prob):
        # Count answers are padded with label -1,
        # so we clamp those paddings to 0 and then mask after `torch.gather()`.
        # Shape: (batch_size, # of count answers)
        gold_count_mask = (answer_as_counts != -1).long()
        # Shape: (batch_size,)
        gold_counts_masked = util.replace_masked_values(answer_as_counts, gold_count_mask, 0)
        count_number_masked = util.replace_masked_values(count_number, gold_count_mask, 0)

        huber_loss = self.huber_loss(count_number_masked, gold_counts_masked.float())

        selection_loss = (1 - (max_prob - min_prob)) * 1000

        # Shape: (batch_size, )
        return huber_loss + selection_loss
Example #19
def aux_window_loss(ptop_attention, passage_mask, inwindow_mask):
    """Auxiliary loss to encourage p-to-p attention to be within a certain window.

    Args:
        ptop_attention: (passage_length, passage_length)
        passage_mask: (passage_length)
        inwindow_mask: (passage_length, passage_length)

    Returns:
        inwindow_aux_loss: ()
    """
    inwindow_mask = inwindow_mask * passage_mask.unsqueeze(0)
    inwindow_mask = inwindow_mask * passage_mask.unsqueeze(1)
    inwindow_probs = ptop_attention * inwindow_mask
    # Sum inwindow_probs for each token, signifying the token can distribute its alignment prob in any way
    # Shape: (passage_length)
    sum_inwindow_probs = inwindow_probs.sum(1)
    # Shape: (passage_length) -- mask for tokens that have empty windows
    mask_sum = (inwindow_mask.sum(1) > 0).float()
    masked_sum_inwindow_probs = allenutil.replace_masked_values(
        sum_inwindow_probs, mask_sum, replace_with=1e-40)
    log_sum_inwindow_probs = torch.log(masked_sum_inwindow_probs +
                                       1e-40) * mask_sum
    inwindow_likelihood = torch.sum(log_sum_inwindow_probs)
    inwindow_likelihood_avg = inwindow_likelihood / torch.sum(inwindow_mask)

    inwindow_aux_loss = -1.0 * inwindow_likelihood_avg

    return inwindow_aux_loss
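A toy usage sketch (assuming `aux_window_loss` above and its `allenutil` import are in scope; the passage length and the ±1-token window are invented for illustration):

import torch

passage_length = 4
ptop_attention = torch.softmax(torch.randn(passage_length, passage_length), dim=-1)
passage_mask = torch.ones(passage_length)
positions = torch.arange(passage_length)
# Each token's window covers itself and its immediate neighbours.
inwindow_mask = ((positions.unsqueeze(0) - positions.unsqueeze(1)).abs() <= 1).float()
loss = aux_window_loss(ptop_attention, passage_mask, inwindow_mask)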
Example #20
def replace_masked_values_with_big_negative_number(x: torch.Tensor,
                                                   mask: torch.Tensor):
    """
    Replace the masked values in a tensor with something really negative so that they
    won't affect a max operation.
    """
    return replace_masked_values(x, mask, min_value_of_dtype(x.dtype))
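A toy usage sketch (assuming the helper above and its `replace_masked_values` / `min_value_of_dtype` imports are in scope):

import torch

logits = torch.tensor([[0.5, 2.0, 9.9]])
mask = torch.tensor([[True, True, False]])
masked_logits = replace_masked_values_with_big_negative_number(logits, mask)
best = masked_logits.argmax(dim=-1)   # picks index 1; the masked 9.9 is ignored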
Example #21
def masked_mean(tensor, dim, mask):
    """
    Performs a mean on just the non-masked portions of ``tensor`` in the
    ``dim`` dimension of the tensor.
    """
    if mask is None:
        return torch.mean(tensor, dim)
    '''print("****")
    print(tensor.size())
    print(mask.size())
    print(tensor.dim())
    print(mask.dim())
    print(dim)
    print("****")'''
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" %
                                 (tensor.dim(), mask.dim()))
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    # total value
    total_tensor = torch.sum(masked_tensor, dim)
    # count
    count_tensor = torch.sum((mask != 0), dim)
    # set zero count to 1 to avoid nans
    # zero_count_mask = (count_tensor == 0)
    zero_count_mask = (count_tensor == 0).long()
    count_plus_zeros = (count_tensor + zero_count_mask).float()
    # average
    mean_tensor = total_tensor / count_plus_zeros
    return mean_tensor
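A toy usage sketch (assuming `masked_mean` above and the helpers it calls are in scope):

import torch

values = torch.tensor([[1.0, 2.0, 3.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
masked_mean(values, 1, mask)   # tensor([2.]) -- the padded 0.0 is excluded from the mean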
Example #22
def masked_mean(tensor, dim, mask):
    """
    Performs a mean on just the non-masked portions of ``tensor`` in the
    ``dim`` dimension of the tensor.

    =====================================================================
    From Decomposable Graph Entailment Model code replicated from SciTail repo
    https://github.com/allenai/scitail
    =====================================================================
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" %
                                 (tensor.dim(), mask.dim()))
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    # total value
    total_tensor = torch.sum(masked_tensor, dim)
    # count
    count_tensor = torch.sum((mask != 0), dim)
    # set zero count to 1 to avoid nans
    zero_count_mask = (count_tensor == 0)
    zero_count_mask = zero_count_mask.long()
    count_plus_zeros = (count_tensor + zero_count_mask).float()
    # average
    mean_tensor = total_tensor / count_plus_zeros
    return mean_tensor
    def gold_log_marginal_likelihood(
            self, gold_answer_representations: Dict[str, torch.LongTensor],
            log_probs: torch.LongTensor,
            question_and_passage_mask: torch.LongTensor,
            passage_mask: torch.LongTensor,
            first_wordpiece_mask: torch.LongTensor,
            is_bio_mask: torch.LongTensor, **kwargs: Any):
        mask = self._get_mask(question_and_passage_mask, passage_mask,
                              first_wordpiece_mask)

        gold_bio_seqs = self._get_gold_answer(gold_answer_representations,
                                              log_probs, mask)
        if self._training_style == 'soft_em':
            log_marginal_likelihood = self._marginal_likelihood(
                gold_bio_seqs, log_probs)
        elif self._training_style == 'hard_em':
            log_marginal_likelihood = self._get_most_likely_likelihood(
                gold_bio_seqs, log_probs)
        else:
            raise Exception("Illegal training_style")

        # For questions without spans, we set their log likelihood to be very small negative value
        log_marginal_likelihood = replace_masked_values(
            log_marginal_likelihood, is_bio_mask, -1e7)

        return log_marginal_likelihood
    def _get_most_likely_likelihood(self, bio_seqs: torch.LongTensor,
                                    log_probs: torch.LongTensor):
        # bio_seqs - Shape: (batch_size, # of correct sequences, seq_length)
        # log_probs - Shape: (batch_size, seq_length, 3)

        # Shape: (batch_size, # of correct sequences, seq_length, 3)
        # duplicate log_probs for each gold bios sequence
        expanded_log_probs = log_probs.unsqueeze(1).expand(
            -1,
            bio_seqs.size()[1], -1, -1)

        # get the log-likelihood per each sequence index
        # Shape: (batch_size, # of correct sequences, seq_length)
        log_likelihoods = \
            torch.gather(expanded_log_probs, dim=-1, index=bio_seqs.unsqueeze(-1)).squeeze(-1)

        # Shape: (batch_size, # of correct sequences)
        correct_sequences_pad_mask = (bio_seqs.sum(-1) > 0).long()

        # Sum the log-likelihoods for each index to get the log-likelihood of the sequence
        # Shape: (batch_size, # of correct sequences)
        sequences_log_likelihoods = log_likelihoods.sum(dim=-1)
        sequences_log_likelihoods = replace_masked_values(
            sequences_log_likelihoods, correct_sequences_pad_mask, -1e7)

        most_likely_sequence_index = sequences_log_likelihoods.argmax(dim=-1)

        return sequences_log_likelihoods.gather(
            dim=1,
            index=most_likely_sequence_index.unsqueeze(-1)).squeeze(dim=-1)
Example #25
    def forward(self,
                embedded_input,
                input_mask,
                other_input=None,
                other_mask=None):

        #assumes input is batch_size * num_words * embedding_dim

        if self._hidden_feedforward is not None:
            embedded_input = self._hidden_feedforward(embedded_input)

        to_cat = []
        if self._max_pool:
            input_max, _ = replace_masked_values(embedded_input,
                                                 input_mask.unsqueeze(-1),
                                                 -1e7).max(dim=1)
            to_cat.append(input_max)
        if self._avg_pool:
            input_avg = torch.sum(
                embedded_input * input_mask.float().unsqueeze(-1),
                dim=1) / torch.sum(input_mask.float(), 1, keepdim=True)
            to_cat.append(input_avg)

        output = torch.cat(to_cat, dim=1)

        if self._projection_feedforward is not None:
            output = self._projection_feedforward(output)

        return output
Example #26
    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:

        # get the relevant scores for the time step
        class_log_probabilities = state['log_probs'][:,
                                                     state['step_num'][0], :]
        is_wordpiece = (
            1 - state['wordpiece_mask'][:, state['step_num'][0]]).byte()

        # mask illegal BIO transitions
        transitions_mask = torch.cat(
            (torch.ones_like(class_log_probabilities[:, :3]),
             torch.zeros_like(class_log_probabilities[:, -2:])),
            dim=-1).byte()
        transitions_mask[:, 2] &= ((last_predictions == 1) |
                                   (last_predictions == 2))
        transitions_mask[:,
                         1:3] &= ((class_log_probabilities[:, :3]
                                   == 0.0).sum(-1) != 3).unsqueeze(-1).repeat(
                                       1, 2)

        # assuming the wordpiece mask doesn't intersect with the other masks (pad, cls/sep)
        transitions_mask[:, 2] |= is_wordpiece & ((last_predictions == 1) |
                                                  (last_predictions == 2))

        class_log_probabilities = replace_masked_values(
            class_log_probabilities, transitions_mask, -1e7)

        state['step_num'] = state['step_num'].clone() + 1
        return class_log_probabilities, state
Example #27
    def posAttnConv(self, sentence, other_sen, interaction, sentence_mask,
                    other_sen_mask, matrix_mask):
        """
        @brief      Compute the position-aware attentive convolution

        @param      self            The object
        @param      sentence        The embedded sentence (n x s x d)
        @param      other_sen       The other sentence (n x s' x d)
        @param      interaction     The interaction matrix (n x s x s')
        @param      sentence_mask   The mask of the sentence (n x s)
        @param      other_sen_mask  The mask of other sentence (n x s')
        @param      matrix_mask     The mask of the interaction matrix (n x s x
                                    s')

        @return     The position-aware attentive convolution
        """
        # calculate the representation of the sentence
        interaction_softmax = last_dim_softmax(
            interaction, other_sen_mask)  # (n x s x s')
        sentence_tilda = weighted_sum(
            other_sen, interaction_softmax)  # (n x s x d)

        # get index of the best-matched word
        _, x = replace_masked_values(interaction, matrix_mask,
                                     -1e7).max(dim=-1)  # (n x s)
        z = self._pos_embedder(x)  # (n x s x dm)

        sentence_combined = torch.cat((sentence_tilda, sentence, z),
                                      dim=2)  # (n x s x (2d + dm))

        return self._pos_attn_encoder(sentence_combined, sentence_mask)
Example #28
    def _get_combined_likelihood(self, answer_as_list_of_bios, log_probs):
        # answer_as_list_of_bios - Shape: (batch_size, # of correct sequences, seq_length)
        # log_probs - Shape: (batch_size, seq_length, 3)

        # Shape: (batch_size, # of correct sequences, seq_length, 3)
        # duplicate log_probs for each gold bios sequence
        expanded_log_probs = log_probs.unsqueeze(1).expand(
            -1,
            answer_as_list_of_bios.size()[1], -1, -1)

        # get the log-likelihood per each sequence index
        # Shape: (batch_size, # of correct sequences, seq_length)
        log_likelihoods = \
            torch.gather(expanded_log_probs, dim=-1, index=answer_as_list_of_bios.unsqueeze(-1)).squeeze(-1)

        # Shape: (batch_size, # of correct sequences)
        correct_sequences_pad_mask = (answer_as_list_of_bios.sum(-1) >
                                      0).long()

        # Sum the log-likelihoods for each index to get the log-likelihood of the sequence
        # Shape: (batch_size, # of correct sequences)
        sequences_log_likelihoods = log_likelihoods.sum(dim=-1)
        sequences_log_likelihoods = replace_masked_values(
            sequences_log_likelihoods, correct_sequences_pad_mask, -1e7)

        # Marginalize (logsumexp) over the correct sequences to get the marginal log-likelihood of the answer
        log_marginal_likelihood = logsumexp(sequences_log_likelihoods, dim=-1)

        return log_marginal_likelihood
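The marginalization over gold BIO sequences can be checked on toy tensors with plain PyTorch (shapes and tag sequences invented for illustration):

import torch

# batch=1, 2 gold BIO sequences (the second is all-zero padding), sequence length 4, 3 tags (O/B/I).
log_probs = torch.log_softmax(torch.randn(1, 4, 3), dim=-1)
gold_bios = torch.tensor([[[0, 1, 2, 0],
                           [0, 0, 0, 0]]])                                  # (1, 2, 4)
expanded = log_probs.unsqueeze(1).expand(-1, gold_bios.size(1), -1, -1)
per_token = torch.gather(expanded, -1, gold_bios.unsqueeze(-1)).squeeze(-1)
pad_mask = (gold_bios.sum(-1) > 0).long()
per_sequence = per_token.sum(-1).masked_fill(pad_mask == 0, -1e7)
marginal = torch.logsumexp(per_sequence, dim=-1)                            # (1,)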
Example #29
    def forward(self, token_representations: torch.LongTensor,
                passage_summary_vector: torch.LongTensor,
                number_indices: torch.LongTensor,
                **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]:

        number_mask = self._get_mask(number_indices,
                                     with_special_numbers=False)

        clamped_number_indices = replace_masked_values(
            number_indices[:, :, 0].long(), number_mask, 0)
        encoded_numbers = torch.gather(
            token_representations, 1,
            clamped_number_indices.unsqueeze(-1).expand(
                -1, -1, token_representations.size(-1)))

        if self._num_special_numbers > 0:
            special_numbers = self._special_embeddings(
                torch.arange(self._num_special_numbers,
                             device=number_indices.device))
            special_numbers = special_numbers.expand(number_indices.shape[0],
                                                     -1, -1)
            encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)

        # Shape: (batch_size, # of numbers, 2*bert_dim)
        encoded_numbers = torch.cat([
            encoded_numbers,
            passage_summary_vector.unsqueeze(1).repeat(
                1, encoded_numbers.size(1), 1)
        ], -1)

        # Shape: (batch_size, # of numbers in the passage, 3)
        logits = self._output_layer(encoded_numbers)
        log_probs = torch.nn.functional.log_softmax(logits, -1)

        number_mask = self._get_mask(number_indices, with_special_numbers=True)
        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(log_probs, -1)
        # For padded numbers, the best sign is masked to 0 (not included).
        best_signs_for_numbers = replace_masked_values(best_signs_for_numbers,
                                                       number_mask, 0)

        output_dict = {
            'log_probs': log_probs,
            'logits': logits,
            'best_signs_for_numbers': best_signs_for_numbers
        }
        return output_dict
Example #30
    def forward(
            self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            logical_forms: Dict[str, torch.LongTensor],
            utterance_string: List[str],
            logical_form_strings: List[List[str]]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------

        Returns
        -------

        """
        # (batch_size, num_utterance_tokens, utterance_embedding_dim)
        embedded_utterance = self.utterance_embedder(utterance)

        # (batch_size, num_logical_forms, num_lf_tokens, lf_embedding_dim)
        embedded_logical_forms = self.logical_form_embedder(
            logical_forms, num_wrapping_dims=1)

        # (batch_size, num_logical_forms, num_lf_tokens)
        logical_form_token_mask = util.get_text_field_mask(logical_forms,
                                                           num_wrapping_dims=1)
        # (batch_size, num_logical_forms)
        logical_form_mask = logical_form_token_mask.sum(dim=-1).clamp(max=1)

        # Because we're just summing everything in the end, we can do the sum upfront to save some
        # time.
        # (batch_size, utterance_embedding_dim)
        encoded_utterance = embedded_utterance.sum(dim=1)

        # (batch_size, num_logical_forms, lf_embedding_dim)
        encoded_logical_forms = embedded_logical_forms.sum(dim=2)

        # (batch_size, num_logical_forms, utterance_embedding_dim)
        predicted_embeddings = self.translation_layer(encoded_logical_forms)

        # (batch_size, num_logical_forms)
        similarities = torch.nn.functional.cosine_similarity(
            predicted_embeddings, encoded_utterance.unsqueeze(1), dim=2)

        # Make sure masked logical forms aren't included in the max.
        similarities = util.replace_masked_values(similarities,
                                                  logical_form_mask, -1e7)

        max_similarity, most_similar = similarities.max(dim=-1)
        loss = (1 - max_similarity).sum()
        most_similar_strings = []
        for instance_most_similar, instance_logical_forms in zip(
                most_similar.tolist(), logical_form_strings):
            most_similar_strings.append(
                instance_logical_forms[instance_most_similar])
        return {
            "loss": loss,
            "most_similar": most_similar_strings,
            "utterance": utterance_string
        }
Example #31
    def forward(self, # pylint: disable=arguments-differ
                embeddings: torch.FloatTensor,
                mask: torch.LongTensor,
                num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                                 torch.LongTensor, torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model).

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``int``, required.
            The number of items to keep when pruning.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, num_items_to_keep).
        top_indices : ``torch.IntTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, 1).
        """
        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)
        # Shape: (batch_size, num_items, 1)
        scores = self._scorer(embeddings)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(f"The scorer passed to Pruner must produce a tensor of shape"
                             f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        scores = util.replace_masked_values(scores, mask, -1e20)

        # Shape: (batch_size, num_items_to_keep, 1)
        _, top_indices = scores.topk(num_items_to_keep, 1)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size, num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Shape: (batch_size * num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

        # Shape: (batch_size, num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
        # Shape: (batch_size, num_items_to_keep)
        top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)

        # Shape: (batch_size, num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

        return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
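The core pruning idea (score, take the top-k, then restore the original ordering) in a self-contained plain-PyTorch sketch with a stand-in scorer (all shapes invented for illustration):

import torch

embeddings = torch.randn(2, 6, 4)                 # (batch, num_items, dim)
scores = embeddings.sum(-1, keepdim=True)         # stand-in scorer: (batch, num_items, 1)
_, top_indices = scores.squeeze(-1).topk(3, dim=1)
top_indices, _ = torch.sort(top_indices, dim=1)   # keep original order, not score order
top_embeddings = torch.gather(
    embeddings, 1, top_indices.unsqueeze(-1).expand(-1, -1, embeddings.size(-1)))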
matrix_attention = LegacyMatrixAttention(similarity_function)

passage_question_similarity = matrix_attention(encoded_passage, encoded_question)
# Shape: (batch_size, passage_length, question_length)
print ("passage question similarity: ", passage_question_similarity.shape)


# Shape: (batch_size, passage_length, question_length)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

# We replace masked values with something really negative here, so they don't affect the
# max below.
masked_similarity = util.replace_masked_values(passage_question_similarity,
                                               question_mask.unsqueeze(1),
                                               -1e7)
# Shape: (batch_size, passage_length)
question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
# Shape: (batch_size, passage_length)
question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
# Shape: (batch_size, encoding_dim)
question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
# Shape: (batch_size, passage_length, encoding_dim)
tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                            passage_length,
                                                                            encoding_dim)

# Shape: (batch_size, passage_length, encoding_dim * 4)
final_merged_passage = torch.cat([encoded_passage,
                                  passage_question_vectors,
                                  encoded_passage * passage_question_vectors,
                                  encoded_passage * tiled_question_passage_vector],
                                 dim=-1)
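# A toy demonstration of the masking trick used above (values and shapes are
# illustrative only, not part of the snippet): pushing padded question columns
# to -1e7 before the max guarantees a padded token is never the maximum for
# any passage position.
import torch

similarity = torch.tensor([[[0.1, 0.9, 5.0],
                            [2.0, 0.3, 7.0]]])      # (batch=1, passage_len=2, question_len=3)
question_mask = torch.tensor([[1., 1., 0.]])        # the third question token is padding

expanded_mask = question_mask.unsqueeze(1)          # broadcast over the passage dimension
masked = similarity * expanded_mask + (1 - expanded_mask) * -1e7
print(masked.max(dim=-1)[0])                        # tensor([[0.9000, 2.0000]]); the padded column is ignored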
Example #33
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
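# A brief sketch (illustrative only, not the model's actual `get_best_span`)
# of the constrained inference mentioned in the docstring above: choose the
# (start, end) pair with the highest summed logits subject to start <= end.
# Real implementations do this in a single vectorized pass; the O(n^2) loops
# here are just for clarity.
import torch

def naive_best_span(span_start_logits, span_end_logits):
    batch_size, passage_length = span_start_logits.size()
    best = torch.zeros(batch_size, 2, dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        for start in range(passage_length):
            for end in range(start, passage_length):
                score = (span_start_logits[b, start] + span_end_logits[b, end]).item()
                if score > best_score:
                    best_score = score
                    best[b] = torch.tensor([start, end])
    return best

start_logits = torch.tensor([[0.1, 2.0, 0.3, -1e7]])   # the last position is masked padding
end_logits = torch.tensor([[0.5, 0.2, 1.5, -1e7]])
print(naive_best_span(start_logits, end_logits))        # tensor([[1, 2]])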
Example #34
    def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
        tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
        mask = torch.FloatTensor([[1, 1, 0]])
        replaced = util.replace_masked_values(tensor, mask.unsqueeze(-1), 2).data.numpy()
        assert_almost_equal(replaced, [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]])
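# A minimal re-implementation consistent with the behaviour the test above
# checks (assumed for illustration; not necessarily the library's own code):
# wherever the broadcast mask is 0, the value is replaced.
import torch

def replace_masked_values_sketch(tensor, mask, replace_with):
    return tensor * mask + (1 - mask) * replace_with

tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
mask = torch.FloatTensor([[1, 1, 0]])
print(replace_masked_values_sketch(tensor, mask.unsqueeze(-1), 2))
# -> the last row becomes [2., 2., 2., 2.], matching the expected value in the test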
Example #35
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens that belong to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac.
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following.
        Each entry is a nested list because we first iterate over dialogs, then over the questions in each dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of lists containing the continuation marker prediction indices.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of lists containing the affirmation marker prediction indices.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
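# A small sketch of the flattened-index arithmetic used in the loss block above
# (toy shapes, illustrative only): for logits of shape
# (total_qa_count, passage_length, 3), the three class logits at (qa, end) sit
# at flat positions qa * passage_length * 3 + end * 3 + {0, 1, 2}.
import torch

total_qa_count, passage_length = 2, 4
span_yesno_logits = torch.arange(total_qa_count * passage_length * 3, dtype=torch.float)
span_yesno_logits = span_yesno_logits.view(total_qa_count, passage_length, 3)

qa, end = 1, 2
flat_locs = torch.tensor([qa * passage_length * 3 + end * 3 + k for k in range(3)])
gathered = span_yesno_logits.view(-1).index_select(0, flat_locs)
assert torch.equal(gathered, span_yesno_logits[qa, end])   # same three logits either way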
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.
        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a
            distribution over the label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        text_mask = util.get_text_field_mask(tokens).float()
        # Pop elmo tokens, since elmo embedder should not be present.
        elmo_tokens = tokens.pop("elmo", None)
        embedded_text = self._text_field_embedder(tokens)

        # Add the "elmo" key back to "tokens" if not None, since the tests and the
        # subsequent training epochs rely not being modified during forward()
        if elmo_tokens is not None:
            tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self._elmo:
            if elmo_tokens is not None:
                elmo_representations = self._elmo(elmo_tokens)["elmo_representations"]
                # Popping from the end of a list is more performant.
                if self._use_integrator_output_elmo:
                    integrator_output_elmo = elmo_representations.pop()
                if self._use_input_elmo:
                    input_elmo = elmo_representations.pop()
                assert not elmo_representations
            else:
                raise ConfigurationError(
                        "Model was built to use Elmo, but input text is not tokenized for Elmo.")

        if self._use_input_elmo:
            embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

        dropped_embedded_text = self._embedding_dropout(embedded_text)
        pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
        encoded_tokens = self._encoder(pre_encoded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.last_dim_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([encoded_tokens,
                                      encoded_tokens - encoded_text,
                                      encoded_tokens * encoded_text], 2)
        integrated_encodings = self._integrator(integrator_input, text_mask)

        # Concatenate ELMo representations to integrated_encodings if specified
        if self._use_integrator_output_elmo:
            integrated_encodings = torch.cat([integrated_encodings,
                                              integrator_output_elmo], dim=-1)

        # Simple Pooling layers
        max_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(text_mask, 1, keepdim=True)

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
                integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

        pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(pooled_representations)

        logits = self._output_layer(pooled_representations_dropped)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
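# A toy demonstration of the masked pooling used above (assumed shapes,
# illustrative only): padding is pushed to -1e7 before max pooling and to +1e7
# before min pooling so it can never be selected, and the mean divides by the
# number of real tokens rather than the padded sequence length.
import torch

encodings = torch.tensor([[[1.0, -4.0], [3.0, 2.0], [9.0, 9.0]]])   # (batch=1, seq=3, dim=2)
mask = torch.tensor([[1.0, 1.0, 0.0]])                              # the last token is padding

mask_3d = mask.unsqueeze(2)
max_pool = (encodings * mask_3d + (1 - mask_3d) * -1e7).max(1)[0]    # tensor([[3., 2.]])
min_pool = (encodings * mask_3d + (1 - mask_3d) * 1e7).min(1)[0]     # tensor([[ 1., -4.]])
mean_pool = (encodings * mask_3d).sum(1) / mask.sum(1, keepdim=True) # tensor([[ 2., -1.]])
print(max_pool, min_pool, mean_pool)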
Example #37
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
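# A quick shape check (toy values, illustrative only) of the "enhancement"
# concatenation used above: [x; attended; x - attended; x * attended] turns a
# (batch, length, dim) encoding into (batch, length, 4 * dim), which the
# projection feedforward then maps back down to the model dimension.
import torch

batch, length, dim = 2, 5, 3
encoded = torch.randn(batch, length, dim)
attended = torch.randn(batch, length, dim)
enhanced = torch.cat([encoded, attended, encoded - attended, encoded * attended], dim=-1)
print(enhanced.shape)   # torch.Size([2, 5, 12])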