def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            get_sample_level_information=True) -> Dict[str, torch.Tensor]:
    """
    Run every submodel, moving each one onto its GPU only while it is
    computing and back to the CPU afterwards, so a single device can host
    an ensemble that does not fit in memory all at once.
    """
    subresults = []
    for submodel in self.submodels:
        submodel.to(device=submodel.cf_a.device)
        subres = submodel(question, passage, span_start, span_end,
                          metadata, get_sample_level_information)
        submodel.to(device=torch.device("cpu"))
        subresults.append(subres)

    batch_size = len(subresults[0]["best_span"])
    best_span = merge_span_probs(subresults)
    output = {
        "best_span": best_span,
        "best_span_str": [],
        "models_output": subresults
    }
    if get_sample_level_information:
        output["em_samples"] = []
        output["f1_samples"] = []

    for index in range(batch_size):
        if metadata is not None:
            passage_str = metadata[index]['original_passage']
            offsets = metadata[index]['token_offsets']
            predicted_span = tuple(best_span[index].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output["best_span_str"].append(best_span_string)
            answer_texts = metadata[index].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
                if get_sample_level_information:
                    em_sample, f1_sample = bidut.get_em_f1_metrics(
                        best_span_string, answer_texts)
                    output["em_samples"].append(em_sample)
                    output["f1_samples"].append(f1_sample)

    if get_sample_level_information:
        # Add information about the individual samples for future analysis.
        output["span_start_sample_loss"] = []
        output["span_end_sample_loss"] = []
        # The averaged distributions are loop-invariant, so compute them once.
        span_start_probs = sum(subresult['span_start_probs']
                               for subresult in subresults) / len(subresults)
        span_end_probs = sum(subresult['span_end_probs']
                             for subresult in subresults) / len(subresults)
        for i in range(batch_size):
            # nll_loss expects log-probabilities, so take the log of the
            # averaged distributions before evaluating it.
            span_start_loss = nll_loss(torch.log(span_start_probs[[i], :]),
                                       span_start.squeeze(-1)[[i]])
            span_end_loss = nll_loss(torch.log(span_end_probs[[i], :]),
                                     span_end.squeeze(-1)[[i]])
            output["span_start_sample_loss"].append(
                float(span_start_loss.detach().cpu().numpy()))
            output["span_end_sample_loss"].append(
                float(span_end_loss.detach().cpu().numpy()))
    return output
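# ``merge_span_probs`` is not defined in this file. The sketch below states
# the semantics the ensemble above appears to rely on (it mirrors AllenNLP's
# BidafEnsemble): average the submodels' span distributions and decode the
# most probable span from the averaged distributions. Treat the name and the
# body as an assumption, not the verified implementation.
def merge_span_probs_sketch(subresults):
    # Average the per-model start/end distributions over the ensemble.
    span_start_probs = sum(r["span_start_probs"]
                           for r in subresults) / len(subresults)
    span_end_probs = sum(r["span_end_probs"]
                         for r in subresults) / len(subresults)
    # The span decoder works on (log-)scores; log() preserves the argmax
    # while matching that scale, and the epsilon guards masked zeros.
    return bidut.get_best_span(torch.log(span_start_probs + 1e-13),
                               torch.log(span_end_probs + 1e-13))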
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            get_sample_level_information=False,
            get_attentions=False) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains
        the answer to the question, and predicts the beginning and ending
        positions of the answer within the passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the beginning position of the answer within the passage.
        This is an `inclusive` token index.  If this is given, we will
        compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the ending position of the answer within the passage.
        This is an `inclusive` token index.  If this is given, we will
        compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question ID, original passage
        text, and token offsets into the passage for each instance in the
        batch.  We use this for computing official metrics using the
        official SQuAD evaluation script.  The length of this list should
        be the batch size, and each dictionary should have the keys ``id``,
        ``original_passage``, and ``token_offsets``.  If you only want the
        best span string and don't care about official metrics, you can
        omit the ``id`` key.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits``
        and ``span_end_logits`` to find the most probable span.  Shape is
        ``(batch_size, 2)`` and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch,
        we also return the string from the original passage that the model
        thinks is the best answer to the question.
    """
""" """ #################### Sample Bayesian weights ################## """ self.sample_posterior() """ ################## MASK COMPUTING ######################## """ question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None """ ###################### EMBEDDING + HIGHWAY LAYER ######################## """ # self.cf_a.use_ELMO if(self.cf_a.Add_Linear_projection_ELMO): embedded_question = self._highway_layer(self._linear_projection_ELMO (self._text_field_embedder(question['character_ids'])["elmo_representations"][-1])) embedded_passage = self._highway_layer(self._linear_projection_ELMO(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1])) else: embedded_question = self._highway_layer(self._text_field_embedder(question['character_ids'])["elmo_representations"][-1]) embedded_passage = self._highway_layer(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1]) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) """ ###################### phrase_layer LAYER ######################## """ encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) """ ###################### Attention LAYER ######################## """ # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
    masked_similarity = util.replace_masked_values(
        passage_question_similarity, question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage,
                                                question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
        batch_size, passage_length, encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat(
        [encoded_passage,
         passage_question_vectors,
         encoded_passage * passage_question_vectors,
         encoded_passage * tiled_question_passage_vector],
        dim=-1)

    modeled_passage = self._dropout_modeling_passage(
        self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # -------------------- Span prediction layers --------------------
    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout_spans_output(
        torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage,
                                                  span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(
        batch_size, passage_length, modeling_dim)
    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat(
        [final_merged_passage,
         modeled_passage,
         tiled_start_representation,
         modeled_passage * tiled_start_representation],
        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout_span_end_encode(
        self._span_end_encoder(span_end_representation, passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout_spans_output(
        torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e7)
    best_span = bidut.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        span_start_loss = nll_loss(
            util.masked_log_softmax(span_start_logits, passage_mask),
            span_start.squeeze(-1))
        span_end_loss = nll_loss(
            util.masked_log_softmax(span_end_logits, passage_mask),
            span_end.squeeze(-1))
        loss = span_start_loss + span_end_loss

        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span,
                            torch.stack([span_start, span_end], -1))

        output_dict["loss"] = loss
        output_dict["span_start_loss"] = span_start_loss
        output_dict["span_end_loss"] = span_end_loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        if get_sample_level_information:
            output_dict["em_samples"] = []
            output_dict["f1_samples"] = []

        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
                if get_sample_level_information:
                    em_sample, f1_sample = bidut.get_em_f1_metrics(
                        best_span_string, answer_texts)
                    output_dict["em_samples"].append(em_sample)
                    output_dict["f1_samples"].append(f1_sample)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens

    if get_sample_level_information:
        # Add information about the individual samples for future analysis.
        output_dict["span_start_sample_loss"] = []
        output_dict["span_end_sample_loss"] = []
        for i in range(batch_size):
            span_start_loss = nll_loss(
                util.masked_log_softmax(span_start_logits[[i], :],
                                        passage_mask[[i], :]),
                span_start.squeeze(-1)[[i]])
            span_end_loss = nll_loss(
                util.masked_log_softmax(span_end_logits[[i], :],
                                        passage_mask[[i], :]),
                span_end.squeeze(-1)[[i]])
            output_dict["span_start_sample_loss"].append(
                float(span_start_loss.detach().cpu().numpy()))
            output_dict["span_end_sample_loss"].append(
                float(span_end_loss.detach().cpu().numpy()))

    if get_attentions:
        output_dict["C2Q_attention"] = passage_question_attention
        output_dict["Q2C_attention"] = question_passage_attention
        output_dict["simmilarity"] = passage_question_similarity

    return output_dict
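# ``bidut.get_best_span`` is assumed to implement the standard constrained
# decoding used by BiDAF: for each row, pick the pair (start, end) with
# start <= end that maximises span_start_logits[start] + span_end_logits[end].
# A minimal single-pass reference version, written only to document the
# assumed behaviour (not the actual bidut code):
def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    batch_size, passage_length = span_start_logits.size()
    best_spans = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        argmax_start = 0  # best start index seen so far, keeps start <= end
        for end in range(passage_length):
            if span_start_logits[b, end] > span_start_logits[b, argmax_start]:
                argmax_start = end
            score = (span_start_logits[b, argmax_start]
                     + span_end_logits[b, end]).item()
            if score > best_score:
                best_score = score
                best_spans[b, 0] = argmax_start
                best_spans[b, 1] = end
    return best_spans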
def forward_ensemble(self,  # type: ignore
                     question: Dict[str, torch.LongTensor],
                     passage: Dict[str, torch.LongTensor],
                     span_start: torch.IntTensor = None,
                     span_end: torch.IntTensor = None,
                     metadata: List[Dict[str, Any]] = None,
                     get_sample_level_information=False) -> Dict[str, torch.Tensor]:
    """
    Monte Carlo ensemble: run one forward pass with the posterior mean
    weights, then ten passes with sampled weights, and merge the eleven
    resulting span distributions.
    """
    self.set_posterior_mean(True)
    most_likely_output = self.forward(question, passage, span_start, span_end,
                                      metadata, get_sample_level_information)
    self.set_posterior_mean(False)

    subresults = [most_likely_output]
    for i in range(10):
        subresults.append(self.forward(question, passage, span_start, span_end,
                                       metadata, get_sample_level_information))

    batch_size = len(subresults[0]["best_span"])
    best_span = bidut.merge_span_probs(subresults)
    output = {
        "best_span": best_span,
        "best_span_str": [],
        "models_output": subresults
    }
    if get_sample_level_information:
        output["em_samples"] = []
        output["f1_samples"] = []

    for index in range(batch_size):
        if metadata is not None:
            passage_str = metadata[index]['original_passage']
            offsets = metadata[index]['token_offsets']
            predicted_span = tuple(best_span[index].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output["best_span_str"].append(best_span_string)
            answer_texts = metadata[index].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
                if get_sample_level_information:
                    em_sample, f1_sample = bidut.get_em_f1_metrics(
                        best_span_string, answer_texts)
                    output["em_samples"].append(em_sample)
                    output["f1_samples"].append(f1_sample)

    if get_sample_level_information:
        # Add information about the individual samples for future analysis.
        output["span_start_sample_loss"] = []
        output["span_end_sample_loss"] = []
        # The averaged distributions are loop-invariant, so compute them once.
        span_start_probs = sum(subresult['span_start_probs']
                               for subresult in subresults) / len(subresults)
        span_end_probs = sum(subresult['span_end_probs']
                             for subresult in subresults) / len(subresults)
        for i in range(batch_size):
            # nll_loss expects log-probabilities, so take the log of the
            # averaged distributions before evaluating it.
            span_start_loss = nll_loss(torch.log(span_start_probs[[i], :]),
                                       span_start.squeeze(-1)[[i]])
            span_end_loss = nll_loss(torch.log(span_end_probs[[i], :]),
                                     span_end.squeeze(-1)[[i]])
            output["span_start_sample_loss"].append(
                float(span_start_loss.detach().cpu().numpy()))
            output["span_end_sample_loss"].append(
                float(span_end_loss.detach().cpu().numpy()))
    return output
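# Hypothetical usage of the Monte Carlo ensemble above, assuming a trained
# ``model`` of this class and an AllenNLP-style ``batch`` dict (the names
# ``model`` and ``batch`` are illustrative, not part of this module):
#
#     model.eval()
#     with torch.no_grad():
#         out = model.forward_ensemble(batch["question"], batch["passage"],
#                                      batch["span_start"], batch["span_end"],
#                                      batch["metadata"],
#                                      get_sample_level_information=True)
#     print(out["best_span_str"][0])                     # decoded answer
#     print(out["em_samples"][0], out["f1_samples"][0])  # per-sample EM / F1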