Example #1
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            get_sample_level_information=True) -> Dict[str, torch.Tensor]:
        """
        WE LOAD THE MODELS ONE INTO GPU  ONE AT A TIME !!!
        """

        subresults = []
        for submodel in self.submodels:
            submodel.to(device=submodel.cf_a.device)
            subres = submodel(question, passage, span_start, span_end,
                              metadata, get_sample_level_information)
            submodel.to(device=torch.device("cpu"))
            subresults.append(subres)

        batch_size = len(subresults[0]["best_span"])

        best_span = merge_span_probs(subresults)
        output = {
            "best_span": best_span,
            "best_span_str": [],
            "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []

        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(
                            best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
        if get_sample_level_information:
            # Add information about the individual samples for future analysis.
            # The ensemble distribution is the average of the submodel
            # distributions, computed once outside the per-sample loop.
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            span_start_probs = sum(
                subresult['span_start_probs']
                for subresult in subresults) / len(subresults)
            span_end_probs = sum(
                subresult['span_end_probs']
                for subresult in subresults) / len(subresults)
            for i in range(batch_size):
                # nll_loss expects log-probabilities, so log the averaged
                # distributions before scoring.
                span_start_loss = nll_loss(span_start_probs[[i], :].log(),
                                           span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_probs[[i], :].log(),
                                         span_end.squeeze(-1)[[i]])

                output["span_start_sample_loss"].append(
                    float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(
                    float(span_end_loss.detach().cpu().numpy()))
        return output
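
The ensemble above calls a `merge_span_probs` helper that is not shown in this section (the later examples import it from `bidut`). A minimal sketch consistent with how the per-sample losses average the submodel distributions, assuming `bidut.get_best_span` accepts `(batch_size, passage_length)` score tensors:

def merge_span_probs(subresults):
    # Average the span distributions across submodels, then decode the most
    # probable (start, end) pair under the start <= end constraint.
    span_start_probs = sum(r["span_start_probs"] for r in subresults) / len(subresults)
    span_end_probs = sum(r["span_end_probs"] for r in subresults) / len(subresults)
    return bidut.get_best_span(span_start_probs.log(), span_end_probs.log())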
Example #2
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information=False,
                get_attentions=False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        
        """
        #################### Sample Bayesian weights ##################
        """
        self.sample_posterior()
        
        """
        ################## MASK COMPUTING ########################
        """
                
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None
        
        """
        ###################### EMBEDDING + HIGHWAY LAYER ########################
        """
#        self.cf_a.use_ELMO
        
        if(self.cf_a.Add_Linear_projection_ELMO):
            embedded_question = self._highway_layer(
                self._linear_projection_ELMO(
                    self._text_field_embedder(
                        question['character_ids'])["elmo_representations"][-1]))
            embedded_passage = self._highway_layer(
                self._linear_projection_ELMO(
                    self._text_field_embedder(
                        passage['character_ids'])["elmo_representations"][-1]))
        else:
            embedded_question = self._highway_layer(
                self._text_field_embedder(
                    question['character_ids'])["elmo_representations"][-1])
            embedded_passage = self._highway_layer(
                self._text_field_embedder(
                    passage['character_ids'])["elmo_representations"][-1])
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        
        """
        ###################### phrase_layer LAYER ########################
        """

        encoded_question = self._dropout_phrase_layer(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout_phrase_layer(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        """
        ###################### Attention LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout_modeling_passage(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)
        
        """
        ###################### Spans LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout_spans_output(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout_span_end_encode(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout_spans_output(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        
        best_span = bidut.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            
            span_start_loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            span_end_loss = nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            loss = span_start_loss + span_end_loss

            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            
            output_dict["loss"] = loss
            output_dict["span_start_loss"] = span_start_loss
            output_dict["span_end_loss"] = span_end_loss
            
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            if get_sample_level_information:
                output_dict["em_samples"] = []
                output_dict["f1_samples"] = []
                
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(
                            best_span_string, answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)
                        
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            
        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            for i in range(batch_size):
                span_start_loss = nll_loss(
                    util.masked_log_softmax(span_start_logits[[i], :],
                                            passage_mask[[i], :]),
                    span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(
                    util.masked_log_softmax(span_end_logits[[i], :],
                                            passage_mask[[i], :]),
                    span_end.squeeze(-1)[[i]])

                output_dict["span_start_sample_loss"].append(
                    float(span_start_loss.detach().cpu().numpy()))
                output_dict["span_end_sample_loss"].append(
                    float(span_end_loss.detach().cpu().numpy()))
        if get_attentions:
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["similarity"] = passage_question_similarity
            
        return output_dict
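
`bidut.get_best_span` performs the constrained inference described in the docstring. A hedged sketch of the standard vectorized decoding (argmax over all spans with start <= end), assuming `(batch_size, passage_length)` score tensors; this follows the usual AllenNLP BiDAF recipe and is not necessarily the author's exact implementation:

import torch

def get_best_span(span_start_logits, span_end_logits):
    batch_size, passage_length = span_start_logits.size()
    # Score every (start, end) pair: shape (batch_size, passage_length, passage_length).
    span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Invalidate spans with end < start: log of an upper-triangular ones matrix
    # is 0 on valid entries and -inf elsewhere.
    span_mask = torch.triu(torch.ones((passage_length, passage_length),
                                      device=span_start_logits.device)).log()
    valid_span_log_probs = span_log_probs + span_mask
    best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
    span_start_indices = best_spans // passage_length
    span_end_indices = best_spans % passage_length
    return torch.stack([span_start_indices, span_end_indices], dim=-1)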
Example #3
    def forward_ensemble(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information=False) -> Dict[str, torch.Tensor]:
        """
        Sample 10 times and add them together
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question, passage, span_start,
                                          span_end, metadata,
                                          get_sample_level_information)
        self.set_posterior_mean(False)
       
        subresults = [most_likely_output]
        for i in range(10):
            subresults.append(
                self.forward(question, passage, span_start, span_end, metadata,
                             get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)
        
        output = {
            "best_span": best_span,
            "best_span_str": [],
            "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(
                            best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
                        
        if get_sample_level_information:
            # Add information about the individual samples for future analysis.
            # The ensemble distribution is the average of the sampled
            # distributions, computed once outside the per-sample loop.
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            span_start_probs = sum(subresult['span_start_probs']
                                   for subresult in subresults) / len(subresults)
            span_end_probs = sum(subresult['span_end_probs']
                                 for subresult in subresults) / len(subresults)
            for i in range(batch_size):
                # nll_loss expects log-probabilities, so log the averaged
                # distributions before scoring.
                span_start_loss = nll_loss(span_start_probs[[i], :].log(),
                                           span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_probs[[i], :].log(),
                                         span_end.squeeze(-1)[[i]])

                output["span_start_sample_loss"].append(
                    float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(
                    float(span_end_loss.detach().cpu().numpy()))
        return output
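
`bidut.get_em_f1_metrics` returns per-sample exact-match and F1 for one predicted string. A plausible sketch, assuming the official SQuAD helpers bundled with AllenNLP (the same functions the `SquadEmAndF1` metric wraps); the helper name and signature follow the calls above:

from allennlp.tools.squad_eval import (exact_match_score, f1_score,
                                       metric_max_over_ground_truths)

def get_em_f1_metrics(best_span_string, answer_texts):
    # Take the max over all ground-truth answers, as in the official evaluator.
    em = metric_max_over_ground_truths(exact_match_score, best_span_string, answer_texts)
    f1 = metric_max_over_ground_truths(f1_score, best_span_string, answer_texts)
    return em, f1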
Example #4
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            get_sample_level_information=False,
            get_attentions=False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        """
        #################### Sample Bayesian weights ##################
        """
        self.sample_posterior()
        """
        ################## MASK COMPUTING ########################
        """

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None
        """
        ###################### EMBEDDING + HIGHWAY LAYER ########################
        """
        #        self.cf_a.use_ELMO

        if self.cf_a.Add_Linear_projection_ELMO:
            embedded_question = self._highway_layer(
                self._linear_projection_ELMO(
                    self._text_field_embedder(question['character_ids'])
                    ["elmo_representations"][-1]))
            embedded_passage = self._highway_layer(
                self._linear_projection_ELMO(
                    self._text_field_embedder(
                        passage['character_ids'])["elmo_representations"][-1]))
        else:
            embedded_question = self._highway_layer(
                self._text_field_embedder(
                    question['character_ids'])["elmo_representations"][-1])
            embedded_passage = self._highway_layer(
                self._text_field_embedder(
                    passage['character_ids'])["elmo_representations"][-1])
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        """
        ###################### phrase_layer LAYER ########################
        """

        encoded_question = self._dropout_phrase_layer(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout_phrase_layer(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)
        """
        ###################### Attention LAYER ########################
        """

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
            batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage,
            passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector,
        ], dim=-1)

        modeled_passage = self._dropout_modeling_passage(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)
        """
        ###################### Spans LAYER ########################
        """

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout_spans_output(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(
            batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage,
            modeled_passage,
            tiled_start_representation,
            modeled_passage * tiled_start_representation,
        ], dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout_span_end_encode(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout_spans_output(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)

        best_span = bidut.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:

            span_start_loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            span_end_loss = nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            loss = span_start_loss + span_end_loss

            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

            output_dict["loss"] = loss
            output_dict["span_start_loss"] = span_start_loss
            output_dict["span_end_loss"] = span_end_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            if get_sample_level_information:
                output_dict["em_samples"] = []
                output_dict["f1_samples"] = []

            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(
                            best_span_string, answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            for i in range(batch_size):
                span_start_loss = nll_loss(
                    util.masked_log_softmax(span_start_logits[[i], :],
                                            passage_mask[[i], :]),
                    span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(
                    util.masked_log_softmax(span_end_logits[[i], :],
                                            passage_mask[[i], :]),
                    span_end.squeeze(-1)[[i]])

                output_dict["span_start_sample_loss"].append(
                    float(span_start_loss.detach().cpu().numpy()))
                output_dict["span_end_sample_loss"].append(
                    float(span_end_loss.detach().cpu().numpy()))
        if get_attentions:
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["similarity"] = passage_question_similarity

        return output_dict
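
For reference, a short usage sketch of the attention outputs added above; `model`, `question`, `passage`, and `metadata` are assumed to come from the usual AllenNLP data pipeline:

out = model.forward(question, passage, metadata=metadata, get_attentions=True)
c2q = out["C2Q_attention"]   # Shape: (batch_size, passage_length, question_length)
q2c = out["Q2C_attention"]   # Shape: (batch_size, passage_length)
sim = out["similarity"]      # Raw passage-question similarity scores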
Example #5
    def forward_ensemble(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            get_sample_level_information=False) -> Dict[str, torch.Tensor]:
        """
        Sample 10 times and add them together
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question, passage, span_start,
                                          span_end, metadata,
                                          get_sample_level_information)
        self.set_posterior_mean(False)

        subresults = [most_likely_output]
        for i in range(10):
            subresults.append(
                self.forward(question, passage, span_start, span_end, metadata,
                             get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)

        output = {
            "best_span": best_span,
            "best_span_str": [],
            "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []

        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(
                            best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)

        if get_sample_level_information:
            # Add information about the individual samples for future analysis.
            # The ensemble distribution is the average of the sampled
            # distributions, computed once outside the per-sample loop.
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            span_start_probs = sum(
                subresult['span_start_probs']
                for subresult in subresults) / len(subresults)
            span_end_probs = sum(
                subresult['span_end_probs']
                for subresult in subresults) / len(subresults)
            for i in range(batch_size):
                # nll_loss expects log-probabilities, so log the averaged
                # distributions before scoring.
                span_start_loss = nll_loss(span_start_probs[[i], :].log(),
                                           span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_probs[[i], :].log(),
                                         span_end.squeeze(-1)[[i]])

                output["span_start_sample_loss"].append(
                    float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(
                    float(span_end_loss.detach().cpu().numpy()))
        return output
Example #6
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information=True) -> Dict[str, torch.Tensor]:
        """
        WE LOAD THE MODELS ONE INTO GPU  ONE AT A TIME !!!
        """
        
        subresults = []
        for submodel in self.submodels:
            submodel.to(device=submodel.cf_a.device)
            subres = submodel(question, passage, span_start, span_end,
                              metadata, get_sample_level_information)
            submodel.to(device=torch.device("cpu"))
            subresults.append(subres)

        batch_size = len(subresults[0]["best_span"])

        best_span = merge_span_probs(subresults)
        output = {
            "best_span": best_span,
            "best_span_str": [],
            "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(
                            best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
        if get_sample_level_information:
            # Add information about the individual samples for future analysis.
            # The ensemble distribution is the average of the submodel
            # distributions, computed once outside the per-sample loop.
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            span_start_probs = sum(subresult['span_start_probs']
                                   for subresult in subresults) / len(subresults)
            span_end_probs = sum(subresult['span_end_probs']
                                 for subresult in subresults) / len(subresults)
            for i in range(batch_size):
                # nll_loss expects log-probabilities, so log the averaged
                # distributions before scoring.
                span_start_loss = nll_loss(span_start_probs[[i], :].log(),
                                           span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_probs[[i], :].log(),
                                         span_end.squeeze(-1)[[i]])

                output["span_start_sample_loss"].append(
                    float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(
                    float(span_end_loss.detach().cpu().numpy()))
        return output
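
The GPU staging pattern above (move a submodel in, run it, move it back to the CPU) can be factored into a small helper so the transfer and eviction always pair up, even if the forward pass raises. A sketch, assuming each submodel exposes its target device as `cf_a.device`:

from contextlib import contextmanager

import torch

@contextmanager
def on_device(model, device):
    # Temporarily move `model` to `device`; always return it to the CPU.
    model.to(device=device)
    try:
        yield model
    finally:
        model.to(device=torch.device("cpu"))

# Usage inside the submodel loop above:
# with on_device(submodel, submodel.cf_a.device):
#     subres = submodel(question, passage, span_start, span_end,
#                       metadata, get_sample_level_information)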