def test_accuracy_computation(self): accuracy = BooleanAccuracy() predictions = torch.Tensor([[0, 1], [2, 3], [4, 5], [6, 7]]) targets = torch.Tensor([[0, 1], [2, 2], [4, 5], [7, 7]]) accuracy(predictions, targets) assert accuracy.get_metric() == 2. / 4 mask = torch.ones(4, 2) mask[1, 1] = 0 accuracy(predictions, targets, mask) assert accuracy.get_metric() == 5. / 8 targets[1, 1] = 3 accuracy(predictions, targets) assert accuracy.get_metric() == 8. / 12 accuracy.reset() accuracy(predictions, targets) assert accuracy.get_metric() == 3. / 4
span_accuracy_function = BooleanAccuracy() squad_metrics_function = SquadEmAndF1() # Compute the loss for training. if span_start is not None: span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) loss = span_start_loss + span_end_loss span_start_accuracy_function(span_start_logits, span_start.squeeze(-1)) span_end_accuracy_function(span_end_logits, span_end.squeeze(-1)) span_accuracy_function(best_span, torch.stack([span_start, span_end], -1)) span_start_accuracy = span_start_accuracy_function.get_metric() span_end_accuracy = span_end_accuracy_function.get_metric() span_accuracy = span_accuracy_function.get_metric() print ("Loss: ", loss) print ("span_start_accuracy: ", span_start_accuracy) print ("span_start_accuracy: ", span_start_accuracy) print ("span_end_accuracy: ", span_end_accuracy) # Compute the EM and F1 on SQuAD and add the tokenized input to the output. if metadata is not None: best_span_str = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens'])
def test_does_not_divide_by_zero_with_no_count(self, device: str): accuracy = BooleanAccuracy() assert accuracy.get_metric() == pytest.approx(0.0)
class DialogQA(Model):
    """
    This class implements modified version of BiDAF
    (with self attention and residual layer, from Clark and Gardner ACL 17 paper)
    model as used in Question Answering in Context (EMNLP 2018) paper
    [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog, list of question answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to
        the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding
        tokens and doing the bidirectional attention.
    residual_encoder : ``Seq2SeqEncoder``
        The encoder used in the residual self-attention block over the merged
        passage-question representation.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage
        state before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage
        state.
    initializer : ``InitializerApplicator``, optional
        If provided, will be used to initialize the model parameters.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders
        (pytorch LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    marker_embedding_dim : ``int``, optional (default=10)
        Dimension of the marker embeddings used to encode dialog context
        (turn number and previous-answer position) into the inputs.
    max_span_length: ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction.
""" def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, phrase_layer: Seq2SeqEncoder, residual_encoder: Seq2SeqEncoder, span_start_encoder: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, initializer: Optional[InitializerApplicator] = None, dropout: float = 0.2, num_context_answers: int = 0, marker_embedding_dim: int = 10, max_span_length: int = 30, max_turn_length: int = 12, ) -> None: super().__init__(vocab) self._num_context_answers = num_context_answers self._max_span_length = max_span_length self._text_field_embedder = text_field_embedder self._phrase_layer = phrase_layer self._marker_embedding_dim = marker_embedding_dim self._encoding_dim = phrase_layer.get_output_dim() self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, "x,y,x*y") self._merge_atten = TimeDistributed( torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim)) self._residual_encoder = residual_encoder if num_context_answers > 0: self._question_num_marker = torch.nn.Embedding( max_turn_length, marker_embedding_dim * num_context_answers) self._prev_ans_marker = torch.nn.Embedding( (num_context_answers * 4) + 1, marker_embedding_dim) self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, "x,y,x*y") self._followup_lin = torch.nn.Linear(self._encoding_dim, 3) self._merge_self_attention = TimeDistributed( torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim)) self._span_start_encoder = span_start_encoder self._span_end_encoder = span_end_encoder self._span_start_predictor = TimeDistributed( torch.nn.Linear(self._encoding_dim, 1)) self._span_end_predictor = TimeDistributed( torch.nn.Linear(self._encoding_dim, 1)) self._span_yesno_predictor = TimeDistributed( torch.nn.Linear(self._encoding_dim, 3)) self._span_followup_predictor = TimeDistributed(self._followup_lin) check_dimensions_match( phrase_layer.get_input_dim(), text_field_embedder.get_output_dim() + marker_embedding_dim * num_context_answers, 
"phrase layer input dim", "embedding dim + marker dim * num context answers", ) if initializer is not None: initializer(self) self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_yesno_accuracy = CategoricalAccuracy() self._span_followup_accuracy = CategoricalAccuracy() self._span_gt_yesno_accuracy = CategoricalAccuracy() self._span_gt_followup_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._official_f1 = Average() self._variational_dropout = InputVariationalDropout(dropout) def forward( # type: ignore self, question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, p1_answer_marker: torch.IntTensor = None, p2_answer_marker: torch.IntTensor = None, p3_answer_marker: torch.IntTensor = None, yesno_list: torch.IntTensor = None, followup_list: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None, ) -> Dict[str, torch.Tensor]: """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. 
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will have assigned 'O', except the passage tokens belonging to
            the previous answer in the dialog, which will be assigned labels such as <1_start>,
            <1_in>, <1_end>.
            For more details, look into
            dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking the previous previous answer in the
            passage.
        p3_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking the previous previous previous
            answer in the passage.
        yesno_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this
            list should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string
            and don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the followings.
        Each of the followings is a nested list: the outer list iterates over dialogs, the
        inner one over the questions in a dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return
            the string from the original passage that the model thinks is the best answer to
            the question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # token_character_ids: rank 4, (batch_size, max_qa_count, max_q_len, max_word_len)
        # as established by the 4-way unpack below.
        token_character_ids = question["token_characters"]["token_characters"]
        batch_size, max_qa_count, max_q_len, _ = token_character_ids.size()
        # The dialog dimension is folded into the batch dimension for the rest of forward().
        total_qa_count = batch_size * max_qa_count
        # NOTE(review): presumably padded QA turns carry followup_list < 0, so this mask
        # selects real turns for the metrics below — confirm against the dataset reader.
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1)
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage)
        # The single passage is shared by every question in the dialog, so its mask is
        # repeated once per QA pair.
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            # Build a (total_qa_count, max_q_len) tensor holding each question's turn
            # index, look up its marker embedding, and append it to the question tokens.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = (embedded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1,
                1).view(total_qa_count, passage_length,
                        self._text_field_embedder.get_output_dim()))
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            # Markers for older answers are only appended when configured, nested so
            # p3 is used only when p2 is.
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)
            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            # No dialog context: encode the passage once and repeat the encoding
            # per QA pair, which avoids re-running the phrase layer.
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                repeated_encoded_passage,
                passage_question_vectors,
                repeated_encoded_passage * passage_question_vectors,
                repeated_encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )
        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        # Residual self-attention block over the merged passage representation.
        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        # Exclude each token from attending to itself (the diagonal).
        self_mask = torch.eye(passage_length,
                              passage_length,
                              dtype=torch.bool,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask & ~self_mask
        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)

        # Span end is predicted from a representation that also sees the span start
        # encoding (start_rep is concatenated in).
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)
        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)
        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)
        best_span = self._get_best_span_yesno_followup(
            span_start_logits,
            span_end_logits,
            span_yesno_logits,
            span_followup_logits,
            self._max_span_length,
        )

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits,
                                        repeated_passage_mask),
                span_start.view(-1),
                ignore_index=-1,
            )
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits,
                                        repeated_passage_mask),
                span_end.view(-1),
                ignore_index=-1,
            )
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(
                best_span[:, 0:2],
                torch.stack([span_start, span_end],
                            -1).view(total_qa_count, 2),
                mask=qa_mask.unsqueeze(1).expand(-1, 2),
            )
            # add a select for the right span to compute loss
            # span_yesno_logits/span_followup_logits flatten to
            # (total_qa_count * passage_length * 3,); entry (i, token, k) lives at
            # i * passage_length * 3 + token * 3 + k, hence the three appends per QA
            # pair below picking the 3 class logits at the gold end token.
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            # Same flattening trick, but at the *predicted* end token.
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            # Yes/no and followup losses are computed at the gold end token...
            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            # ...while the accuracies are measured at the predicted end token.
            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict["best_span_str"] = []
        output_dict["qid"] = []
        output_dict["followup"] = []
        output_dict["yesno"] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        # One outer iteration per dialog; inner iteration per question in that dialog.
        for i in range(batch_size):
            passage_str = metadata[i]["original_passage"]
            offsets = metadata[i]["token_offsets"]
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

                # Map the predicted token span back to character offsets in the
                # original passage text.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        # Leave-one-out: score against every subset of N-1
                        # references and average, as in the QuAC evaluation.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad.metric_max_over_ground_truths(
                                    squad.f1_score, best_span_string, refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad.metric_max_over_ground_truths(
                            squad.f1_score, best_span_string, answer_texts)
                self._official_f1(100 * f1_score)
            output_dict["qid"].append(per_dialog_query_id_list)
            output_dict["best_span_str"].append(per_dialog_best_span_list)
            output_dict["yesno"].append(per_dialog_yesno_list)
            output_dict["followup"].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Convert the predicted yes/no and followup label indices into their
        # vocabulary token strings.
        yesno_tags = [[
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in yn_list
        ] for yn_list in output_dict.pop("yesno")]
        followup_tags = [[
            self.vocab.get_token_from_index(x, namespace="followup_labels")
            for x in followup_list
        ] for followup_list in output_dict.pop("followup")]
        output_dict["yesno"] = yesno_tags
        output_dict["followup"] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Metrics tracked during forward(); `reset` clears the running counts.
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            "followup": self._span_followup_accuracy.get_metric(reset),
            "f1": self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span_yesno_followup(
        span_start_logits: torch.Tensor,
        span_end_logits: torch.Tensor,
        span_yesno_logits: torch.Tensor,
        span_followup_logits: torch.Tensor,
        max_span_length: int,
    ) -> torch.Tensor:
        # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as
        # yesno prediction bit and followup prediction bit from the predicted span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        # Columns: [start, end, yesno_pred, followup_pred].
        best_word_span = span_start_logits.new_zeros((batch_size, 4),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()

        # Single left-to-right pass per batch element: track the best start seen
        # so far and pair it with each candidate end, keeping the best-scoring
        # (start, end) whose length does not exceed max_span_length.
        for b_i in range(batch_size):
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2

        # Read the yes/no and followup predictions off the chosen end token.
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
class BertQA(Model):
    """
    NOTE(review): this docstring appears to be copied from the BiDAF model — the
    parameters documented below do not match ``__init__``'s actual signature
    (``sim_text_field_embedder``, ``loss_weights``, ``sim_class_weights``, etc.).
    Verify against the constructor and rewrite.

    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the
        model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.
        This gives a ~2x speedup, with only a slight performance decrease, if any.  We haven't
        experimented much with this yet, but have confirmed that we still get very similar
        performance with much faster training times.  We still use the mask for all softmaxes,
        but avoid the shuffling that's required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sim_text_field_embedder: TextFieldEmbedder,
                 loss_weights: Dict,
                 sim_class_weights: List,
                 pretrained_sim_path: str = None,
                 use_scenario_encoding: bool = True,
                 sim_pretraining: bool = False,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BertQA, self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        # A separate BERT embedder for the Scenario Interpretation Module,
        # only instantiated when scenario encoding is enabled.
        if use_scenario_encoding:
            self._sim_text_field_embedder = sim_text_field_embedder
        self.loss_weights = loss_weights
        self.sim_class_weights = sim_class_weights
        self.use_scenario_encoding = use_scenario_encoding
        self.sim_pretraining = sim_pretraining

        if self.sim_pretraining and not self.use_scenario_encoding:
            raise ValueError(
                "When pretraining Scenario Interpretation Module, you should use it."
            )

        embedding_dim = self._text_field_embedder.get_output_dim()
        # Heads: 4-way action classifier, 4-way per-token scenario labels,
        # and per-token (start, end) span logits.
        self._action_predictor = torch.nn.Linear(embedding_dim, 4)
        self._sim_token_label_predictor = torch.nn.Linear(embedding_dim, 4)
        self._span_predictor = torch.nn.Linear(embedding_dim, 2)

        self._action_accuracy = CategoricalAccuracy()
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_loss_metric = Average()
        self._action_loss_metric = Average()
        self._sim_loss_metric = Average()
        # F1 over scenario token classes 2 ("yes") and 3 ("no").
        self._sim_yes_f1 = F1Measure(2)
        self._sim_no_f1 = F1Measure(3)

        if use_scenario_encoding and pretrained_sim_path is not None:
            logger.info("Loading pretrained model..")
            self.load_state_dict(torch.load(pretrained_sim_path))
            # Freeze the pretrained scenario embedder.
            for param in self._sim_text_field_embedder.parameters():
                param.requires_grad = False

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def get_passage_representation(self, bert_output, bert_input):
        # Shape: (batch_size, bert_input_len)
        input_type_ids = self.get_input_type_ids(
            bert_input['bert-type-ids'], bert_input['bert-offsets'],
            self._text_field_embedder._token_embedders['bert']).float()
        # Shape: (batch_size, bert_input_len)
        input_mask = util.get_text_field_mask(bert_input).float()
        passage_mask = input_mask - input_type_ids  # works only with one [SEP]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        passage_representation = bert_output * passage_mask.unsqueeze(2)
        # Keep only positions that belong to the passage in at least one batch
        # element (sum over the batch dimension).
        # Shape: (batch_size, passage_len, embedding_dim)
        passage_representation = passage_representation[:, passage_mask.sum(
            dim=0) > 0, :]
        # Shape: (batch_size, passage_len)
        passage_mask = passage_mask[:, passage_mask.sum(dim=0) > 0]
        return passage_representation, passage_mask

    def forward(self,  # type: ignore
                bert_input: Dict[str, torch.LongTensor],
                sim_bert_input: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
        """
        NOTE(review): parameter docs below mention ``question``/``passage`` but
        this method takes ``bert_input``/``sim_bert_input`` — likely copied from
        another model; verify and update.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to
            the question, and predicts the beginning and ending positions of the answer within
            the passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token
            index.  If this is given, we will compute a loss that gets included in the output
            dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output
            dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original
            passage text, and token offsets into the passage for each instance in the batch.
            The length of this list should be the batch size, and each dictionary should have
            the keys ``question_tokens``, ``passage_tokens``, ``original_passage``, and
            ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is
            ``(batch_size, 2)`` and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch,
            we also return the string from the original passage that the model
            thinks is the best answer to the question.
        """
        if self.use_scenario_encoding:
            # Auxiliary "sim" model: predict a 4-way label for every scenario
            # (passage) wordpiece.  The "_wp" suffix marks wordpiece-level tensors.
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_bert_input_token_labels_wp = sim_bert_input[
                'scenario_gold_encoding']
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input)
            # Non-padding positions.  Shape: (batch_size, sim_bert_input_len_wp)
            sim_input_mask_wp = (sim_bert_input['bert'] != 0).float()
            # Keep only the first segment (token-type id 0), i.e. the passage.
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[
                'bert-type-ids'].float()  # works only with one [SEP]
            # Zero out everything outside the passage segment.
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_bert_output_wp * sim_passage_mask_wp.unsqueeze(
                2)
            # Drop columns that are masked for the whole batch.
            # NOTE(review): the sum is over dim=0 (the batch axis), so the kept
            # columns are shared across the batch — confirm this is intended
            # rather than a per-instance (dim=1) selection.
            # Shape: (batch_size, passage_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_passage_representation_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0, :]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_mask_wp = sim_passage_mask_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0]
            # Shape: (batch_size, passage_len_wp, 4)
            sim_token_logits_wp = self._sim_token_label_predictor(
                sim_passage_representation_wp)
            if span_start is not None:  # during training and validation
                # Class-weighted token-level cross entropy; label 0 is ignored
                # (padding).
                class_weights = torch.tensor(self.sim_class_weights,
                                             device=sim_token_logits_wp.device,
                                             dtype=torch.float)
                sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4),
                                         sim_passage_token_labels_wp.view(-1),
                                         ignore_index=0,
                                         weight=class_weights)
                self._sim_loss_metric(sim_loss.item())
                self._sim_yes_f1(sim_token_logits_wp,
                                 sim_passage_token_labels_wp,
                                 sim_passage_mask_wp)
                self._sim_no_f1(sim_token_logits_wp,
                                sim_passage_token_labels_wp,
                                sim_passage_mask_wp)
                if self.sim_pretraining:
                    # When only pretraining the sim model, skip the main model.
                    return {'loss': sim_loss}

            if not self.sim_pretraining:
                # Feed the sim model's hard predictions to the main model,
                # trimmed or zero-padded to the main model's wordpiece length.
                # Shape: (batch_size, passage_len_wp)
                bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax(
                    dim=2)) * sim_passage_mask_wp.long()
                # Shape: (batch_size, bert_input_len_wp)
                bert_input_wp_len = bert_input['history_encoding'].size(1)
                if bert_input['scenario_encoding'].size(1) > bert_input_wp_len:
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = bert_input[
                        'scenario_encoding'][:, :bert_input_wp_len]
                else:
                    batch_size = bert_input['scenario_encoding'].size(0)
                    difference = bert_input_wp_len - bert_input[
                        'scenario_encoding'].size(1)
                    zeros = torch.zeros(
                        batch_size,
                        difference,
                        dtype=bert_input['scenario_encoding'].dtype,
                        device=bert_input['scenario_encoding'].device)
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = torch.cat(
                        [bert_input['scenario_encoding'], zeros], dim=1)

        # Main model: BERT over the full input; position 0 is the [CLS] vector.
        # Shape: (batch_size, bert_input_len + 1, embedding_dim)
        bert_output = self._text_field_embedder(bert_input)
        # Shape: (batch_size, embedding_dim)
        pooled_output = bert_output[:, 0]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        bert_output = bert_output[:, 1:, :]
        # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len)
        passage_representation, passage_mask = self.get_passage_representation(
            bert_output, bert_input)

        # Shape: (batch_size, 4)
        action_logits = self._action_predictor(pooled_output)

        # Shape: (batch_size, passage_len, 2)
        span_logits = self._span_predictor(passage_representation)
        # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1)
        span_start_logits, span_end_logits = span_logits.split(1, dim=2)
        # Shape: (batch_size, passage_len)
        span_start_logits = span_start_logits.squeeze(2)
        # Shape: (batch_size, passage_len)
        span_end_logits = span_end_logits.squeeze(2)

        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Push masked positions to a very negative value so the span search
        # never picks them.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "pooled_output": pooled_output,
            "passage_representation": passage_representation,
            "action_logits": action_logits,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if self.use_scenario_encoding:
            output_dict["sim_token_logits"] = sim_token_logits_wp

        # Compute the loss for training (and for validation)
        if span_start is not None:
            # Span loss only applies to instances whose gold action is 'More'
            # (i.e. the answer is a passage span); others are masked out.
            # Shape: (batch_size,)
            span_loss = nll_loss(util.masked_log_softmax(
                span_start_logits, passage_mask),
                                 span_start.squeeze(1),
                                 reduction='none')
            # Shape: (batch_size,)
            span_loss += nll_loss(util.masked_log_softmax(
                span_end_logits, passage_mask),
                                  span_end.squeeze(1),
                                  reduction='none')
            # Shape: (batch_size,)
            more_mask = (label == self.vocab.get_token_index(
                'More', namespace="labels")).float()
            # Mean over 'More' instances; epsilon guards against an all-zero mask.
            span_loss = (span_loss * more_mask).sum() / (more_mask.sum() +
                                                         1e-6)

            if more_mask.sum() > 1e-7:
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(1), more_mask)
                self._span_end_accuracy(span_end_logits, span_end.squeeze(1),
                                        more_mask)
                # Shape: (batch_size, 2)
                span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long()
                self._span_accuracy(best_span,
                                    torch.cat([span_start, span_end], dim=1),
                                    span_acc_mask)

            action_loss = cross_entropy(action_logits, label)
            self._action_accuracy(action_logits, label)

            self._span_loss_metric(span_loss.item())
            self._action_loss_metric(action_loss.item())
            output_dict['loss'] = self.loss_weights[
                'span_loss'] * span_loss + self.loss_weights[
                    'action_loss'] * action_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if not self.training:  # true during validation and test
            output_dict['best_span_str'] = []
            batch_size = len(metadata)
            for i in range(batch_size):
                passage_text = metadata[i]['passage_text']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Map token indices back to character offsets in the passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_str = passage_text[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_str)
                if 'gold_span' in metadata[i]:
                    if metadata[i]['action'] == 'More':
                        gold_span = metadata[i]['gold_span']
                        self._squad_metrics(best_span_str, [gold_span])

        return output_dict

    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Add human-readable action probabilities and label strings to the output."""
        action_probs = softmax(output_dict['action_logits'], dim=1)
        output_dict['action_probs'] = action_probs
        predictions = action_probs.cpu().data.numpy()
        argmax_indices = numpy.argmax(predictions, axis=1)
        labels = [
            self.vocab.get_token_from_index(x, namespace="labels")
            for x in argmax_indices
        ]
        output_dict['label'] = labels
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Collect tracked metrics; EM/F1 are only reported outside training."""
        if self.use_scenario_encoding:
            sim_loss = self._sim_loss_metric.get_metric(reset)
            _, _, yes_f1 = self._sim_yes_f1.get_metric(reset)
            _, _, no_f1 = self._sim_no_f1.get_metric(reset)

        if self.sim_pretraining:
            return {'sim_macro_f1': (yes_f1 + no_f1) / 2}

        # The accuracy metrics raise ZeroDivisionError before any update has
        # been recorded; report 0 in that case.
        try:
            action_acc = self._action_accuracy.get_metric(reset)
        except ZeroDivisionError:
            action_acc = 0
        try:
            start_acc = self._span_start_accuracy.get_metric(reset)
        except ZeroDivisionError:
            start_acc = 0
        try:
            end_acc = self._span_end_accuracy.get_metric(reset)
        except ZeroDivisionError:
            end_acc = 0
        try:
            span_acc = self._span_accuracy.get_metric(reset)
        except ZeroDivisionError:
            span_acc = 0

        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        span_loss = self._span_loss_metric.get_metric(reset)
        action_loss = self._action_loss_metric.get_metric(reset)
        # Weighted aggregate used for model selection.
        agg_metric = span_acc + action_acc * 0.45

        metrics = {
            'action_acc': action_acc,
            'span_acc': span_acc,
            'span_loss': span_loss,
            'action_loss': action_loss,
            'agg_metric': agg_metric
        }

        if self.use_scenario_encoding:
            metrics['sim_macro_f1'] = (yes_f1 + no_f1) / 2

        if not self.training:  # during validation
            metrics['em'] = exact_match
            metrics['f1'] = f1_score

        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """Return the highest-scoring valid (start, end) span for each instance.

        Both inputs have shape ``(batch_size, passage_length)``; the result has
        shape ``(batch_size, 2)`` with inclusive token indices.
        """
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.  log of the 0/1 triangle gives 0 / -inf.
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask
        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)

    def get_input_type_ids(self, type_ids, offsets, embedder):
        "Converts (bsz, seq_len_wp) to (bsz, seq_len_wp) by indexing."
        # NOTE(review): after indexing with `offsets` the second dimension is
        # token-level, so the output shape looks like (bsz, seq_len) rather
        # than (bsz, seq_len_wp) — confirm against callers and fix docstring.
        batch_size = type_ids.size(0)
        full_seq_len = type_ids.size(1)
        if full_seq_len > embedder.max_pieces:
            # Recombine if we had used sliding window approach.
            # Only supported for batch size 1 with a non-trivial second segment.
            assert batch_size == 1 and type_ids.max() > 0
            num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero(
            ).size(0)
            select_indices = embedder.indices_to_select(
                full_seq_len, num_question_tokens)
            type_ids = type_ids[:, select_indices]
        # Gather the type id of the first wordpiece of each token.
        range_vector = util.get_range_vector(
            batch_size, device=util.get_device_of(type_ids)).unsqueeze(1)
        type_ids = type_ids[range_vector, offsets]
        return type_ids
class QaNet(Model):
    """
    This class implements Adams Wei Yu's `QANet Model
    <https://openreview.net/forum?id=B14TlG-RW>`_ for machine reading comprehension published at
    ICLR 2018.

    The overall architecture of QANet is very similar to BiDAF. The main difference is that QANet
    replaces the RNN encoder with CNN + self-attention. There are also some minor differences in
    the modeling layer and output layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the
        model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the passage-question attention.
    matrix_attention_layer : ``MatrixAttention``
        The matrix attention function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    dropout_prob : ``float``, optional (default=0.1)
        If greater than 0, we will apply dropout with this probability between layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention_layer: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        dropout_prob: float = 0.1,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        # Project token embeddings into the phrase layer's input space, then
        # refine them with highway layers.
        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim, encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim, encoding_in_dim)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        # The merged attention representation concatenates 4 encoding-sized
        # pieces (see forward), hence the * 4 on the projection input.
        self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4, modeling_in_dim)
        self._modeling_layer = modeling_layer

        # Start/end predictors each see a concatenation of two modeling-layer
        # outputs (hence * 2).
        self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()
        self._dropout = torch.nn.Dropout(p=dropout_prob) if dropout_prob > 0 else lambda x: x

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token
            index.  If this is given, we will compute a loss that gets included in the output
            dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # Embed, project, and run highway layers (shared between question and passage).
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask)
        )
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask)
        )

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True
        )
        # Passage-to-question attended vectors.
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True
        )

        # QANet's "attention over attention": compose the two attention maps to
        # get passage-to-passage attention.
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)

        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            )
        )

        # Run the (shared) modeling layer three times; keep every intermediate
        # output so start/end predictors can mix different depths.
        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask)
            )
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        # Make masked positions effectively impossible for the span search.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)
            )
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)
            )
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Map token indices back to character offsets in the passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies plus SQuAD exact-match and F1."""
        exact_match, f1_score = self._metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }
class RobertaSpanPredictionModel(Model):
    """
    Span-prediction (extractive QA) head on top of a RoBERTa encoder: a single
    linear layer produces start/end logits per token, trained with
    cross-entropy, evaluated with SQuAD EM/F1.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 transformer_weights_model: str = None,
                 layer_freeze_regexes: List[str] = None,
                 on_load: bool = False,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        if on_load:
            # Weights will come from the archive being loaded; build from
            # config only to avoid downloading pretrained weights.
            logging.info(f"Skipping loading of initial Transformer weights")
            transformer_config = RobertaConfig.from_pretrained(
                pretrained_model)
            self._transformer_model = RobertaModel(transformer_config)
        elif transformer_weights_model:
            # Warm-start the encoder from another archived AllenNLP model.
            logging.info(
                f"Loading Transformer weights model from {transformer_weights_model}"
            )
            transformer_model_loaded = load_archive(transformer_weights_model)
            self._transformer_model = transformer_model_loaded.model._transformer_model
        else:
            self._transformer_model = RobertaModel.from_pretrained(
                pretrained_model)

        # Freeze parameters matching any of the given regexes (only relevant
        # when requires_grad is True).
        for name, param in self._transformer_model.named_parameters():
            grad = requires_grad
            if layer_freeze_regexes and grad:
                grad = not any(
                    [bool(re.search(r, name)) for r in layer_freeze_regexes])
            param.requires_grad = grad

        transformer_config = self._transformer_model.config
        num_labels = 2  # For start/end
        self.qa_outputs = Linear(transformer_config.hidden_size, num_labels)

        # Import GTP2 machinery to get from tokens to actual text
        self.byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        # Counts down; debug printing happens only for the first forward call.
        self._debug = 2
        self._padding_value = 1  # The index of the RoBERTa padding token

    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                start_positions: torch.LongTensor = None,
                end_positions: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        # NOTE(review): the return annotation says torch.Tensor but the method
        # returns an output dictionary — should be Dict[str, torch.Tensor].
        self._debug -= 1
        input_ids = tokens['tokens']

        batch_size = input_ids.size(0)
        # NOTE(review): named num_choices but this is the sequence length for
        # span prediction — presumably copied from a multiple-choice model.
        num_choices = input_ids.size(1)

        # Non-padding positions.
        tokens_mask = (input_ids != self._padding_value).long()

        if self._debug > 0:
            print(f"batch_size = {batch_size}")
            print(f"num_choices = {num_choices}")
            print(f"tokens_mask = {tokens_mask}")
            print(f"input_ids.size() = {input_ids.size()}")
            print(f"input_ids = {input_ids}")
            print(f"segment_ids = {segment_ids}")
            print(f"start_positions = {start_positions}")
            print(f"end_positions = {end_positions}")

        # Segment ids are not used by RoBERTa
        transformer_outputs = self._transformer_model(
            input_ids=input_ids,
            # token_type_ids=segment_ids,
            attention_mask=tokens_mask)
        sequence_output = transformer_outputs[0]

        # Shared head produces 2 logits per token; split into start and end.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # Masked positions get a very negative logit so the span search and
        # the softmaxed probabilities ignore them.
        span_start_logits = util.replace_masked_values(start_logits,
                                                       tokens_mask, -1e7)
        span_end_logits = util.replace_masked_values(end_logits, tokens_mask,
                                                     -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)
        span_start_probs = util.masked_softmax(span_start_logits, tokens_mask)
        span_end_probs = util.masked_softmax(span_end_logits, tokens_mask)
        output_dict = {
            "start_logits": start_logits,
            "end_logits": end_logits,
            "best_span": best_span
        }
        output_dict["start_probs"] = span_start_probs
        output_dict["end_probs"] = span_end_probs

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            self._span_start_accuracy(span_start_logits, start_positions)
            self._span_end_accuracy(span_end_logits, end_positions)
            self._span_accuracy(
                best_span,
                torch.cat([
                    start_positions.unsqueeze(-1),
                    end_positions.unsqueeze(-1)
                ], -1))

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
            # Should we mask out invalid positions here?
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = total_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['exact_match'] = []
            output_dict['f1_score'] = []
            tokens_texts = []
            for i in range(batch_size):
                tokens_text = metadata[i]['tokens']
                tokens_texts.append(tokens_text)
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                predicted_start = predicted_span[0]
                predicted_end = predicted_span[1]
                # End index is inclusive, hence the +1 on the slice.
                predicted_tokens = tokens_text[predicted_start:(predicted_end +
                                                                1)]
                best_span_string = self.convert_tokens_to_string(
                    predicted_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                exact_match = 0
                f1_score = 0
                if answer_texts:
                    exact_match, f1_score = self._squad_metrics(
                        best_span_string, answer_texts)
                output_dict['exact_match'].append(exact_match)
                output_dict['f1_score'].append(f1_score)
            output_dict['tokens_texts'] = tokens_texts

        if self._debug > 0:
            print(f"output_dict = {output_dict}")

        return output_dict

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        # BPE tokens are byte-encoded; map back to raw bytes then decode UTF-8,
        # replacing any invalid byte sequences.
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors='replace')
        return text

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies plus SQuAD exact-match and F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @classmethod
    def _load(cls,
              config: Params,
              serialization_dir: str,
              weights_file: str = None,
              cuda_device: int = -1,
              **kwargs) -> 'Model':
        # Set on_load so __init__ skips fetching pretrained weights; the
        # archived weights are restored by the base class instead.
        model_params = config.get('model')
        model_params.update({"on_load": True})
        config.update({'model': model_params})
        return super()._load(config=config,
                             serialization_dir=serialization_dir,
                             weights_file=weights_file,
                             cuda_device=cuda_device,
                             **kwargs)
class BidafPlusSelfAttention(Model): def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, phrase_layer: Seq2SeqEncoder, residual_encoder: Seq2SeqEncoder, span_start_encoder: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, initializer: InitializerApplicator, dropout: float = 0.2, mask_lstms: bool = True) -> None: super(BidafPlusSelfAttention, self).__init__(vocab) self._text_field_embedder = text_field_embedder self._phrase_layer = phrase_layer self._matrix_attention = TriLinearAttention(200) self._merge_atten = TimeDistributed(torch.nn.Linear(200 * 4, 200)) self._residual_encoder = residual_encoder self._self_atten = TriLinearAttention(200) self._merge_self_atten = TimeDistributed(torch.nn.Linear(200 * 3, 200)) self._span_start_encoder = span_start_encoder self._span_end_encoder = span_end_encoder self._span_start_predictor = TimeDistributed(torch.nn.Linear(200, 1)) self._span_end_predictor = TimeDistributed(torch.nn.Linear(200, 1)) initializer(self) self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._official_em = Average() self._official_f1 = Average() if dropout > 0: # self._dropout = torch.nn.Dropout(p=dropout) self._dropout = VariationalDropout(p=dropout) else: raise ValueError() # self._dropout = lambda x: x self._mask_lstms = mask_lstms def forward(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. 
span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` index. If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key. Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log probabilities of the span end position (inclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. """ embedded_question = self._dropout(self._text_field_embedder(question)) embedded_passage = self._dropout(self._text_field_embedder(passage)) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * tiled_question_passage_vector], dim=-1) final_merged_passage = F.relu(self._merge_atten(final_merged_passage)) residual_layer = self._dropout(self._residual_encoder(self._dropout(final_merged_passage), passage_mask)) self_atten_matrix = self._self_atten(residual_layer, residual_layer) mask = passage_mask.resize(batch_size, passage_length, 1) * passage_mask.resize(batch_size, 1, passage_length) # torch.eye does not have a gpu implementation, so we are forced to use the cpu one and .cuda() # Not sure if this matters for performance self_mask = Variable(torch.eye(passage_length, passage_length).cuda()).resize(1, passage_length, passage_length) mask = mask * (1 - self_mask) self_atten_probs = util.last_dim_softmax(self_atten_matrix, mask) # Batch matrix multiplication: # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim) self_atten_vecs = torch.matmul(self_atten_probs, residual_layer) residual_layer = F.relu(self._merge_self_atten(torch.cat( [self_atten_vecs, residual_layer, residual_layer * self_atten_vecs], dim=-1))) final_merged_passage += residual_layer final_merged_passage = 
self._dropout(final_merged_passage) start_rep = self._span_start_encoder(final_merged_passage, passage_lstm_mask) span_start_logits = self._span_start_predictor(start_rep).squeeze(-1) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1), passage_lstm_mask) span_end_logits = self._span_end_predictor(end_rep).squeeze(-1) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = self._get_best_span(span_start_logits, span_end_logits) output_dict = {"span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span} if span_start is not None: loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) output_dict["loss"] = loss if metadata is not None: output_dict['best_span_str'] = [] for i in range(batch_size): passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].data.cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output_dict['best_span_str'].append(best_span_string) answer_texts = metadata[i].get('answer_texts', []) exact_match = f1_score = 0 if answer_texts: exact_match = squad_eval.metric_max_over_ground_truths( squad_eval.exact_match_score, best_span_string, 
                        answer_texts)
                    # F1 against the best-matching gold answer, mirroring the EM
                    # computation started just above this span.
                    f1_score = squad_eval.metric_max_over_ground_truths(
                        squad_eval.f1_score,
                        best_span_string,
                        answer_texts)
                # Official metrics are tracked as percentages (0-100).
                self._official_em(100 * exact_match)
                self._official_f1(100 * f1_score)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return accumulated span accuracies and official SQuAD EM/F1.

        EM and F1 are already scaled to percentages when they are recorded in
        ``forward``, so no further scaling happens here.
        """
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': self._official_em.get_metric(reset),
                'f1': self._official_f1.get_metric(reset),
                }

    @staticmethod
    def _get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable:
        """Constrained decoding: for each batch element, pick the (start, end)
        pair with start <= end that maximizes start_logit + end_logit.

        Runs a single left-to-right pass on the CPU, tracking the best start
        seen so far for every candidate end position.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        # Shape: (batch_size, 2); column 0 is the start index, column 1 the end index.
        best_word_span = Variable(span_start_logits.data.new()
                                  .resize_(batch_size, 2).fill_(0)).long()

        # Decode on the CPU in numpy; the scalar-heavy loop below would be slow
        # as individual tensor ops.
        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # val1: best start logit among positions <= j (keeps start <= end).
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'DialogQA':
        """Construct the model from a ``Params`` object, popping every key so
        that ``assert_empty`` can flag unrecognized configuration."""
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        residual_encoder = Seq2SeqEncoder.from_params(params.pop("residual_encoder"))
        span_start_encoder = Seq2SeqEncoder.from_params(params.pop("span_start_encoder"))
        span_end_encoder = Seq2SeqEncoder.from_params(params.pop("span_end_encoder"))
        initializer = InitializerApplicator.from_params(params.pop("initializer", []))
        dropout = params.pop('dropout', 0.2)

        # TODO: Remove the following when fully deprecated
        evaluation_json_file = params.pop('evaluation_json_file', None)
        if evaluation_json_file is not None:
            logger.warning("the 'evaluation_json_file' model parameter is deprecated, please remove")

        mask_lstms = params.pop('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   phrase_layer=phrase_layer,
                   residual_encoder=residual_encoder,
                   span_start_encoder=span_start_encoder,
                   span_end_encoder=span_end_encoder,
                   initializer=initializer,
                   dropout=dropout,
                   mask_lstms=mask_lstms)
class TransformerQA(Model):
    """
    This class implements a reading comprehension model patterned after the proposed model in
    https://arxiv.org/abs/1810.04805 (Devlin et al), with improvements borrowed from the SQuAD
    model in the transformers project.

    It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

    Note that the metrics that the model produces are calculated on a per-instance basis only.
    Since there could be more than one instance per question, these metrics are not the official
    numbers on the SQuAD task. To get official numbers, run the script in
    scripts/transformer_qa_eval.py.

    Parameters
    ----------
    vocab : ``Vocabulary``
    transformer_model_name : ``str``, optional (default=``bert-base-cased``)
        This model chooses the embedder according to this setting. You probably want to make sure
        this is set to the same thing as the reader.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 transformer_model_name: str = "bert-base-cased",
                 hidden_size: int = 768,
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = BasicTextFieldEmbedder({
            "tokens": PretrainedTransformerEmbedder(transformer_model_name,
                                                    hidden_size=hidden_size,
                                                    task="QA")
        })
        # Projects each word-piece embedding to two scores: (start_logit, end_logit).
        self._linear_layer = nn.Linear(self._text_field_embedder.get_output_dim(), 2)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._per_instance_metrics = SquadEmAndF1()

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question_with_context : Dict[str, torch.LongTensor]
            From a ``TextField``. The model assumes that this text field contains the context
            followed by the question. It further assumes that the tokens have type ids set such
            that any token that can be part of the answer (i.e., tokens from the context) has type
            id 0, and any other token (including [CLS] and [SEP]) has type id 1.
        context_span : ``torch.IntTensor``
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which
            answers can come.
        answer_span : ``torch.IntTensor``, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text
            that marks the answer. If given, we compute a loss that gets included in the output
            directory.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question id, and the original texts of context,
            question, tokenized version of both, and a list of possible answers. The length of the
            ``metadata`` list should be the batch size, and each dictionary should have the keys
            ``id``, ``question``, ``context``, ``question_tokens``, ``context_tokens``, and
            ``answers``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)`` and
            each offset is a token index.
        best_span_scores : torch.FloatTensor
            The score for each of the best spans.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._text_field_embedder(question_with_context)
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        # Only word pieces inside the context span may be predicted as answer boundaries.
        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context), dtype=torch.bool)
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = True

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask, -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # Score of a span is the sum of its start and end logits.
        best_span_scores = torch.gather(
            span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) + torch.gather(
            span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss for training.
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            # Instances with no answer use -1; mask them out of the accuracy metrics.
            span_mask = span_start != -1
            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            start_loss = cross_entropy(span_start_logits, span_start, ignore_index=-1)
            if torch.any(start_loss > 1e9):
                logger.critical("Start loss too high (%r)", start_loss)
                logger.critical("span_start_logits: %r", span_start_logits)
                logger.critical("span_start: %r", span_start)
                # Raise instead of `assert False`: asserts are stripped under `-O`,
                # which would let training silently continue past a divergence.
                raise ValueError("Start loss too high")

            end_loss = cross_entropy(span_end_logits, span_end, ignore_index=-1)
            if torch.any(end_loss > 1e9):
                logger.critical("End loss too high (%r)", end_loss)
                logger.critical("span_end_logits: %r", span_end_logits)
                logger.critical("span_end: %r", span_end)
                raise ValueError("End loss too high")

            loss = (start_loss + end_loss) / 2

            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()
            output_dict["best_span_str"] = []
            context_tokens = []
            for metadata_entry, best_span in zip(metadata, best_spans):
                context_tokens_for_question = metadata_entry["context_tokens"]
                context_tokens.append(context_tokens_for_question)

                # Shift from positions in the full [CLS] question [SEP] context input
                # to positions within the context token list alone.
                best_span -= 1 + len(metadata_entry["question_tokens"]) + 2
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                # Word pieces without a character offset (`idx is None`) cannot be mapped
                # back to the original text; walk outward to the nearest mappable token.
                while (predicted_start >= 0
                       and context_tokens_for_question[predicted_start].idx is None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[predicted_start].idx

                while (predicted_end < len(context_tokens_for_question)
                       and context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][character_start:character_end]
                output_dict["best_span_str"].append(best_span_string)

                # Default to an empty list so a missing "answers" key doesn't raise a
                # TypeError on the truthiness check below.
                answers = metadata_entry.get("answers", [])
                if answers:
                    self._per_instance_metrics(best_span_string, answers)
            output_dict["context_tokens"] = context_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies plus per-instance (unofficial) SQuAD EM/F1."""
        exact_match, f1_score = self._per_instance_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "per_instance_em": exact_match,
            "per_instance_f1": f1_score,
        }
def main(
    gpu: int,
    qa_model_path: str,
    paragraphs_source: str,
    generated_decompositions_paths: Optional[str],
    data: str,
    output_predictions_file: str,
    output_metrics_file: str,
    overrides="{}",
):
    """Run iterative decomposition-based QA over a StrategyQA dataset file.

    Loads the QA model archive, answers each instance's decomposition steps
    bottom-up (resolving references between steps as answers become known),
    scores the final yes/no answers with ``BooleanAccuracy``, and writes the
    per-instance predictions and the aggregate metrics to the given files.

    Parameters
    ----------
    gpu : CUDA device id for the loaded archive.
    qa_model_path : Path to the trained model archive.
    paragraphs_source : Paragraph-retrieval mode; when it mentions
        "generated_decomposition" the generated decomposition is used.
    generated_decompositions_paths : Optional paths to generated decompositions.
    data : Path to the input dataset (JSON).
    output_predictions_file : Where to write per-instance outputs (or None to skip).
    output_metrics_file : Where to write metrics (or None to print them).
    overrides : JSON string of config overrides for the archive.
    """
    import_module_and_submodules("src")

    overrides_dict = {}
    overrides_dict.update(json.loads(overrides))
    archive = load_archive(qa_model_path, cuda_device=gpu, overrides=json.dumps(overrides_dict))
    predictor = Predictor.from_archive(archive)

    dataset_reader = StrategyQAReader(
        paragraphs_source=paragraphs_source,
        generated_decompositions_paths=generated_decompositions_paths,
    )

    accuracy = BooleanAccuracy()
    last_logged_scores_time = time.monotonic()

    logger.info("Reading the dataset:")
    logger.info("Reading file at %s", data)
    with open(data, mode="r", encoding="utf-8") as dataset_file:
        dataset = json.load(dataset_file)

    output_dataset = []
    for json_obj in tqdm(dataset):
        item = dataset_reader.json_to_item(json_obj)
        decomposition = item["decomposition"]
        generated_decomposition = item["generated_decomposition"]
        gold_answer = torch.tensor(item["answer"]).view((1,))

        used_decomposition = deepcopy(
            generated_decomposition if "generated_decomposition" in paragraphs_source
            else decomposition
        )

        # Per instance:
        # Until the final step has an answer, find in each iteration
        # all of the steps that are required to answer the last step (including by proxy)
        # and don't have references in them.
        # If it is not possible, return a score of zero for the instance.
        # If it is possible, retrieve paragraphs for these steps,
        # and then pass the step and the paragraphs for it to be answered by the model.
        # Replace the answer in all of the steps that has a reference for it.
        step_answers = [None] * len(used_decomposition)
        while True:
            reachability = get_reachability([step["question"] for step in used_decomposition])
            if reachability is None:
                # Unresolvable reference structure; the final answer stays None.
                break
            if step_answers[-1] is not None:
                break

            # Steps that the last step (transitively) depends on and that have no
            # unresolved references of their own are answerable now.
            indices_of_interest = []
            if sum(reachability[-1]) != 0:
                for i, reachable in enumerate(reachability[-1]):
                    if reachable > 0 and sum(reachability[i]) == 0:
                        indices_of_interest.append(i)
            else:
                indices_of_interest.append(len(step_answers) - 1)

            paragraphs = dataset_reader.get_paragraphs(
                decomposition=[used_decomposition[i] for i in indices_of_interest],
            )
            if paragraphs is not None:
                paragraphs_per_step_of_interest = paragraphs["per_step"]
            else:
                # Fall back to an empty paragraph per step so the model can still run.
                paragraphs_per_step_of_interest = [[{"content": " "}] for i in indices_of_interest]

            for i in indices_of_interest:
                step_answers[i] = get_answer(
                    predictor=predictor,
                    question=used_decomposition[i]["question"],
                    paragraphs=paragraphs_per_step_of_interest[indices_of_interest.index(i)],
                    force_yes_no=i == len(step_answers) - 1,
                )

            # Substitute the freshly computed answers into referencing steps.
            for i, step in enumerate(used_decomposition):
                used_decomposition[i]["question"] = fill_in_references(
                    step["question"], step_answers
                )

        predicted_answer_str = step_answers[-1].lower() if step_answers[-1] is not None else None
        if predicted_answer_str == "yes" or predicted_answer_str == "no":
            # Valid answer, the metric should be updated accordingly
            predicted_answer = torch.tensor(predicted_answer_str == "yes").view((1,))
            accuracy(predicted_answer, gold_answer)
        else:
            # Invalid answer, the metric should be updated with a mistake.
            # Use a tensor here: BooleanAccuracy calls `.size()` on its predictions,
            # so the previous `not gold_answer` (a Python bool) crashed this path.
            accuracy(torch.logical_not(gold_answer), gold_answer)

        # Throttled progress logging (at most every ~3 seconds).
        if time.monotonic() - last_logged_scores_time > 3:
            metrics_dict = {"accuracy": accuracy.get_metric()}
            logger.info(json.dumps(metrics_dict))
            last_logged_scores_time = time.monotonic()

        output_json_obj = deepcopy(json_obj)
        output_json_obj["decomposition"] = [step["question"] for step in used_decomposition]
        output_json_obj["step_answers"] = step_answers
        output_dataset.append(output_json_obj)

    if output_predictions_file is not None:
        with open(output_predictions_file, "w", encoding="utf-8") as f:
            json.dump(output_dataset, f, ensure_ascii=False, indent=4)

    metrics_dict = {"accuracy": accuracy.get_metric(reset=True)}
    if output_metrics_file is None:
        print(json.dumps(metrics_dict))
    else:
        with open(output_metrics_file, "w", encoding="utf-8") as f:
            json.dump(
                metrics_dict,
                f,
                ensure_ascii=False,
                indent=4,
            )
class ESIMCosine(Model):
    """ESIM-style sentence-pair encoder trained with a contrastive (margin) loss.

    Encodes premise and hypothesis, runs bidirectional cross-attention and an
    inference encoder (the ESIM "enhancement" pipeline), pools each side, maps
    both through a shared feedforward pair, and predicts a match when the
    pairwise distance between the two outputs is below ``margin / 2``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the ``premise`` and ``hypothesis`` ``TextFields``.
    encoder : ``Seq2SeqEncoder``
        Context encoder applied to both sentences.
    similarity_function : ``SimilarityFunction``
        Used (via ``LegacyMatrixAttention``) to compare encoded premise and
        hypothesis tokens.
    projection_feedforward : ``FeedForward``
        Projects the 4x-concatenated "enhanced" representations back down.
    inference_encoder : ``Seq2SeqEncoder``
        Second encoder run over the projected representations.
    output_feedforward : ``FeedForwardPair``
        Maps the pooled premise/hypothesis vectors into the final metric space.
    dropout : ``float``, optional (default=0.5)
        Applied both variationally to RNN inputs and to the pooled vectors.
    margin : ``float``, optional (default=1.25)
        Contrastive-loss margin; the decision threshold is ``margin / 2``.
    initializer : ``InitializerApplicator``
    regularizer : ``RegularizerApplicator``, optional
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForwardPair,
                 dropout: float = 0.5,
                 margin: float = 1.25,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward
        self._inference_encoder = inference_encoder

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            # Variational dropout for the RNN inputs (same mask across time steps).
            self.rnn_input_dropout = InputVariationalDropout(dropout)
        else:
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._margin = margin
        self._accuracy = BooleanAccuracy()
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """Compute the distance between premise and hypothesis encodings.

        Returns an output dict with ``distance`` (pairwise distance per pair),
        ``prediction`` (True when distance < margin / 2), and, when ``label``
        is given, a contrastive ``loss``.
        """
        # Shape: (batch_size, seq_length, embedding_dim)
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        mask_premise = get_text_field_mask(premise).float()
        mask_hypothesis = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        # Shape: (batch_size, seq_length, encoding_direction_num * encoding_hidden_dim)
        encoded_premise = self._encoder(embedded_premise, mask_premise)
        encoded_hypothesis = self._encoder(embedded_hypothesis, mask_hypothesis)

        # Shape: (batch_size, p_length, h_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, p_length, h_length)
        p2h_attention = masked_softmax(similarity_matrix, mask_hypothesis)
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, h_length, p_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), mask_premise)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer: concat [a, b, a - b, a * b] per token.
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim * 4 + num_perspective * num_matching)
        enhanced_premise = torch.cat([
            encoded_premise, attended_hypothesis,
            encoded_premise - attended_hypothesis,
            encoded_premise * attended_hypothesis
        ], dim=-1)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim * 4 + num_perspective * num_matching)
        enhanced_hypothesis = torch.cat([
            encoded_hypothesis, attended_premise,
            encoded_hypothesis - attended_premise,
            encoded_hypothesis * attended_premise
        ], dim=-1)

        # The projection layer down to the model dimension. Dropout is not applied before
        # projection.
        # Shape: (batch_size, seq_length, projection_hidden_dim)
        projected_enhanced_premise = self._projection_feedforward(enhanced_premise)
        projected_enhanced_hypothesis = self._projection_feedforward(enhanced_hypothesis)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        # Shape: (batch_size, seq_length, inference_direction_num * inference_hidden_dim)
        inferenced_premise = self._inference_encoder(projected_enhanced_premise, mask_premise)
        inferenced_hypothesis = self._inference_encoder(projected_enhanced_hypothesis, mask_hypothesis)

        # The pooling layer -- max and avg pooling.
        # Masked positions are pushed to -1e7 so they never win the max.
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim)
        pooled_premise_max, _ = replace_masked_values(
            inferenced_premise, mask_premise.unsqueeze(-1), -1e7).max(dim=1)
        pooled_hypothesis_max, _ = replace_masked_values(
            inferenced_hypothesis, mask_hypothesis.unsqueeze(-1), -1e7).max(dim=1)

        # Mean over unmasked positions only.
        pooled_premise_avg = torch.sum(
            inferenced_premise * mask_premise.unsqueeze(-1),
            dim=1) / torch.sum(mask_premise, 1, keepdim=True)
        pooled_hypothesis_avg = torch.sum(
            inferenced_hypothesis * mask_hypothesis.unsqueeze(-1),
            dim=1) / torch.sum(mask_hypothesis, 1, keepdim=True)

        # Now concat
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim * 2)
        pooled_premise_all = torch.cat([pooled_premise_avg, pooled_premise_max], dim=1)
        pooled_hypothesis_all = torch.cat([pooled_hypothesis_avg, pooled_hypothesis_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            pooled_premise_all = self.dropout(pooled_premise_all)
            pooled_hypothesis_all = self.dropout(pooled_hypothesis_all)

        # Shape: (batch_size, output_feedforward_hidden_dim)
        output_premise, output_hypothesis = self._output_feedforward(
            pooled_premise_all, pooled_hypothesis_all)

        distance = F.pairwise_distance(output_premise, output_hypothesis)
        # Pairs closer than half the margin are predicted as matches.
        prediction = distance < (self._margin / 2.0)
        output_dict = {'distance': distance, "prediction": prediction}

        if label is not None:
            """
            Contrastive loss function.
            Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            """
            y = label.float()
            # l1 pulls positive pairs together; l2 pushes negatives apart up to the margin.
            l1 = y * torch.pow(distance, 2) / 2.0
            l2 = (1 - y) * torch.pow(
                torch.clamp(self._margin - distance, min=0.0), 2) / 2.0
            loss = torch.mean(l1 + l2)
            self._accuracy(prediction, label.byte())
            output_dict["loss"] = loss

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated boolean match accuracy."""
        return {'accuracy': self._accuracy.get_metric(reset)}
class BidirectionalAttentionFlow(Model): """ This class implements Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). The basic layout is pretty simple: encode words as a combination of word embeddings and a character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of attentions to put question information into the passage word representations (this is the only part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. attention_similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. initializer : ``InitializerApplicator`` We will use this to initialize the parameters in the model, calling ``initializer(self)``. 
dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. evaluation_json_file : ``str``, optional If given, we will load this JSON into memory and use it to compute official metrics against. We need this separately from the validation dataset, because the official metrics use all of the annotations, while our dataset reader picks the most frequent one. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, num_highway_layers: int, phrase_layer: Seq2SeqEncoder, attention_similarity_function: SimilarityFunction, modeling_layer: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, initializer: InitializerApplicator, dropout: float = 0.2, mask_lstms: bool = True) -> None: super(BidirectionalAttentionFlow, self).__init__(vocab) self._text_field_embedder = text_field_embedder self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(), num_highway_layers)) self._phrase_layer = phrase_layer self._matrix_attention = MatrixAttention(attention_similarity_function) self._modeling_layer = modeling_layer self._span_end_encoder = span_end_encoder encoding_dim = phrase_layer.get_output_dim() modeling_dim = modeling_layer.get_output_dim() span_start_input_dim = encoding_dim * 4 + modeling_dim self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1)) span_end_encoding_dim = span_end_encoder.get_output_dim() span_end_input_dim = encoding_dim * 4 + 
span_end_encoding_dim self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1)) initializer(self) # Bidaf has lots of layer dimensions which need to match up - these # aren't necessarily obvious from the configuration files, so we check # here. if modeling_layer.get_input_dim() != 4 * encoding_dim: raise ConfigurationError("The input dimension to the modeling_layer must be " "equal to 4 times the encoding dimension of the phrase_layer. " "Found {} and 4 * {} respectively.".format(modeling_layer.get_input_dim(), encoding_dim)) if text_field_embedder.get_output_dim() != phrase_layer.get_input_dim(): raise ConfigurationError("The output dimension of the text_field_embedder (embedding_dim + " "char_cnn) must match the input dimension of the phrase_encoder. " "Found {} and {}, respectively.".format(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim())) if span_end_encoder.get_input_dim() != encoding_dim * 4 + modeling_dim * 3: raise ConfigurationError("The input dimension of the span_end_encoder should be equal to " "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. " "Found {} and (4 * {} + 3 * {}) " "respectively.".format(span_end_encoder.get_input_dim(), encoding_dim, modeling_dim)) self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._mask_lstms = mask_lstms def forward(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. 
        The model assumes that this passage contains the answer to the question, and predicts
        the beginning and ending positions of the answer within the passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``. This is one of the things we are trying to predict - the
            beginning position of the answer with the passage. This is an `inclusive` index. If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``. This is one of the things we are trying to predict - the
            ending position of the answer with the passage. This is an `inclusive` index. If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch. We use this for computing
            official metrics using the official SQuAD evaluation script. The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``. If you only want the best span string
            and don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return
            the string from the original passage that the model thinks is the best answer to
            the question.
        """
        # Embed question and passage, then pass each through the shared highway layer.
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        # Passing the mask to the LSTMs is optional (``self._mask_lstms``); the softmaxes
        # below always use the mask regardless.
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Mask out impossible positions before the constrained argmax in _get_best_span.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self._get_best_span(span_start_logits, span_end_logits)

        output_dict = {"span_start_logits": span_start_logits,
                       "span_start_probs": span_start_probs,
                       "span_end_logits": span_end_logits,
                       "span_end_probs": span_end_probs,
                       "best_span": best_span}
        # Loss and accuracy metrics are only computed when gold spans are provided
        # (training / validation).
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        # With metadata we can map the predicted token span back to a passage substring
        # and update the official SQuAD EM/F1 metrics.
        if metadata is not None:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span-start/end/joint accuracies plus SQuAD EM and F1, optionally resetting."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def _get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable:
        """
        Constrained decoding: for each batch element, find the ``(start, end)`` pair with
        ``start <= end`` maximising ``span_start_logits[start] + span_end_logits[end]``.
        Uses a single left-to-right pass per instance on CPU numpy copies of the logits.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new()
                                  .resize_(batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # val1 tracks the best start logit seen at or before position j.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BidirectionalAttentionFlow':
        """Build the model from a ``Params`` configuration, consuming all expected keys."""
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        num_highway_layers = params.pop("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(params.pop("similarity_function"))
        modeling_layer = Seq2SeqEncoder.from_params(params.pop("modeling_layer"))
        span_end_encoder = Seq2SeqEncoder.from_params(params.pop("span_end_encoder"))
        initializer = InitializerApplicator.from_params(params.pop("initializer", []))
        dropout = params.pop('dropout', 0.2)

        # TODO: Remove the following when fully deprecated
        evaluation_json_file = params.pop('evaluation_json_file', None)
        if evaluation_json_file is not None:
            logger.warning("the 'evaluation_json_file' model parameter is deprecated, please remove")

        mask_lstms = params.pop('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   modeling_layer=modeling_layer,
                   span_end_encoder=span_end_encoder,
                   initializer=initializer,
                   dropout=dropout,
                   mask_lstms=mask_lstms)
class DocLevelmpeEsim(mpeEsim):
    """
    Document-level extension of ``mpeEsim``.

    Runs the parent (sentence-level) model once per hypothesis sentence, collects the
    per-sentence "entailment" logits, pads them out to ``max_sent_count``, and feeds
    them through a small fc1 -> fc2 -> fc3 stack to produce one document-level logit.
    """

    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            encoder: Seq2SeqEncoder,
            projection_feedforward: FeedForward,
            inference_encoder: Seq2SeqEncoder,
            output_feedforward: FeedForward,
            output_logit: FeedForward,
            final_feedforward: FeedForward,
            coverage_loss: CoverageLoss,
            similarity_function: SimilarityFunction = DotProductSimilarity(),
            dropout: float = 0.5,
            contextualize_pair_comparators: bool = False,
            pair_context_encoder: Seq2SeqEncoder = None,
            pair_feedforward: FeedForward = None,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        # Need to send it verbatim because otherwise FromParams doesn't work appropriately.
        super().__init__(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=encoder,
            similarity_function=similarity_function,
            projection_feedforward=projection_feedforward,
            inference_encoder=inference_encoder,
            output_feedforward=output_feedforward,
            output_logit=output_logit,
            final_feedforward=final_feedforward,
            contextualize_pair_comparators=contextualize_pair_comparators,
            coverage_loss=coverage_loss,
            pair_context_encoder=pair_context_encoder,
            pair_feedforward=pair_feedforward,
            dropout=dropout,
            initializer=initializer,
            regularizer=regularizer)
        self._answer_loss = torch.nn.BCELoss()
        # Upper bound on hypothesis sentences per document; per-sentence logits are
        # zero-padded to this width before the aggregation head.
        self.max_sent_count = 120
        # Aggregation head over the padded per-sentence entailment logits.
        self.fc1 = torch.nn.Linear(self.max_sent_count, 10)
        self.fc2 = torch.nn.Linear(10, 5)
        self.fc3 = torch.nn.Linear(5, 1)
        self.out_sigmoid = torch.nn.Sigmoid()
        self._accuracy = BooleanAccuracy()

    @overrides
    def forward(self,  # type: ignore
                premises: Dict[str, torch.LongTensor],
                hypotheses: Dict[str, torch.LongTensor],
                paragraph: Dict[str, torch.LongTensor],
                answer_index: torch.LongTensor = None,
                relevance_presence_mask: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Score each hypothesis sentence with the parent model and aggregate into a single
        document-level logit.  When ``answer_index`` is given, also computes a BCE loss
        and updates the boolean-accuracy metric.
        """
        # Split the batched hypotheses into one TextField dict per sentence.
        hypothesis_list = unbind_tensor_dict(hypotheses, dim=1)
        label_logits = []
        premises_attentions = []
        premises_aggregation_attentions = []
        #coverage_losses = []
        for hypothesis in hypothesis_list:
            # single hypothesis even to the parent class
            #print("super().forward",len(premises), len(hypothesis), len(paragraph))
            output_dict = super().forward(premises=premises,
                                          hypothesis=hypothesis,
                                          paragraph=paragraph)  #paragraph?
            # Keep only the "entailment" column of the parent's label logits.
            individual_logit = output_dict["label_logits"][:, self._label2idx[
                "entailment"]]  # only useful key
            label_logits.append(individual_logit)
            #
            premises_attention = output_dict.get("premises_attention", None)
            premises_attentions.append(premises_attention)
            premises_aggregation_attention = output_dict.get(
                "premises_aggregation_attention", None)
            premises_aggregation_attentions.append(
                premises_aggregation_attention)
            #if relevance_presence_mask is not None:
            #coverage_loss = output_dict["coverage_loss"]
            #coverage_losses.append(coverage_loss)
            # Explicitly drop the per-sentence tensors before the next iteration.
            del output_dict, individual_logit, premises_attention, premises_aggregation_attention

        label_logits = torch.stack(label_logits, dim=-1)
        premises_attentions = torch.stack(premises_attentions, dim=1)
        premises_aggregation_attentions = torch.stack(
            premises_aggregation_attentions, dim=1)
        #if relevance_presence_mask is not None:
        #coverage_losses = torch.stack(coverage_losses, dim=0)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        # @todo: Check covariance of label_logits and label_probs
        # Zero-pad the sentence dimension up to max_sent_count so fc1 always sees a
        # fixed-width input.
        if label_logits.shape[1] < self.max_sent_count:
            label_logits = torch.nn.functional.pad(
                input=label_logits,
                pad=(0, self.max_sent_count - label_logits.shape[1], 0, 0),
                mode='constant',
                value=0)
        single_output_logit = self.fc3(self.fc2(self.fc1(label_logits)))
        sigmoid_output = self.out_sigmoid(single_output_logit)
        #import pdb; pdb.set_trace()
        output_dict = {
            "label_logits": single_output_logit,
            "label_probs": sigmoid_output,
            "premises_attentions": premises_attentions,
            "premises_aggregation_attentions": premises_aggregation_attentions
        }
        if answer_index is not None:
            #print("_answer_loss",single_output_logit, answer_index)
            cudadevice = single_output_logit.device
            # torch.device('cuda:'+ str(single_output_logit.get_device()))
            temp_tensor = torch.tensor([[k] for k in answer_index]).to(cudadevice)
            sgd = torch.nn.Sigmoid()
            # NOTE(review): the sigmoid is applied to the *targets* as well as the
            # predictions, so a gold label of 0 becomes 0.5 and 1 becomes ~0.73 before
            # BCELoss — confirm this is intentional; BCE targets are normally the raw
            # 0/1 labels.
            loss = self._answer_loss(sgd(single_output_logit), sgd(temp_tensor.float()))
            output_dict["loss"] = loss
            output_dict["novelty"] = (single_output_logit > 0.5)
            # NOTE(review): rebuilt here without .to(cudadevice), so it stays on CPU
            # while the predictions may live on GPU — verify BooleanAccuracy tolerates
            # the device mismatch.
            temp_tensor = torch.tensor([[k] for k in answer_index])
            #print("_answer_loss",single_output_logit, temp_tensor)
            self._accuracy(single_output_logit > 0.5, temp_tensor.byte())
            del temp_tensor, loss, cudadevice
            #self._accuracy(single_output_logit>0.5, answer_index)
        del label_logits, label_probs, hypothesis_list,
        # if answer_index is not None:
        #     answer_loss
        #     loss = self._answer_loss(label_logits, answer_index)
        #     coverage loss
        #     if relevance_presence_mask is not None:
        #         loss += coverage_losses.mean()
        #     output_dict["loss"] = loss
        #     self._accuracy(label_logits, answer_index)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the document-level boolean accuracy, optionally resetting the metric."""
        accuracy_metric = self._accuracy.get_metric(reset)
        return {'accuracy': accuracy_metric}
# Compute the loss and running accuracies when gold spans are available.
if span_start is not None:
    span_start_loss = nll_loss(
        util.masked_log_softmax(span_start_logits, passage_mask),
        span_start.squeeze(-1))
    span_end_loss = nll_loss(
        util.masked_log_softmax(span_end_logits, passage_mask),
        span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss

    # Update the running metric objects with this batch's predictions.
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))
    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()

    print("Loss: ", loss)
    print("span_start_accuracy: ", span_start_accuracy)
    print("span_end_accuracy: ", span_end_accuracy)
    # Fix: the original printed span_start_accuracy twice and never printed the joint
    # span_accuracy, even though it was computed above.
    print("span_accuracy: ", span_accuracy)

# Compute the EM and F1 on SQuAD and add the tokenized input to the output.
if metadata is not None:
    best_span_str = []
    question_tokens = []
    passage_tokens = []
    for i in range(batch_size):
        question_tokens.append(metadata[i]['question_tokens'])
        passage_tokens.append(metadata[i]['passage_tokens'])
        passage_str = metadata[i]['original_passage']
class GMN(Model):
    """
    Gated Memory Network over a chain of five events.  Each (trigger, agent, object)
    event is embedded, the last event's vector is refined for ``args.hop_num`` hops of
    GRU-style gated attention over all five event embeddings, and the result is scored
    against the last trigger to predict the label.
    """

    def __init__(self, args, word_embeddings: TextFieldEmbedder, vocab: Vocabulary) -> None:
        super().__init__(vocab)
        # parameters
        self.args = args
        self.word_embeddings = word_embeddings
        # gate: GRU-style update (z), reset (r) and candidate projections,
        # each mapping embedding_size -> 1 with no bias.
        self.W_z = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.U_z = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.W_r = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.U_r = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.W = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.U = nn.Linear(self.args.embedding_size, 1, bias=False)
        # layers
        self.event_embedding = EventEmbedding(args, self.word_embeddings)
        self.attention = Attention(self.args.embedding_size, score_function='mlp')
        self.sigmoid = Sigmoid()
        self.tanh = Tanh()
        self.score = Score(self.args.embedding_size, self.args.embedding_size,
                           threshold=self.args.threshold)
        # metrics
        self.accuracy = BooleanAccuracy()
        self.f1_score = F1Measure(positive_label=1)
        self.loss_function = BCELoss()

    def gated_atten(self, vt_1, atten_input):
        """
        gated attention block
        :param vt_1: v_t-1
        :param atten_input: [h1, h2, ... ,h_n-1]
        :return: v_t
        """
        # [batch_size, 1, embedding_size]
        out_at, _ = self.attention(atten_input, vt_1)
        # [batch_size, embedding_size] — attention-weighted sum of the inputs.
        h_e = torch.sum(out_at * atten_input, dim=1)
        # [batch_size, 1] — update gate.
        z = (self.sigmoid(self.W_z(h_e.unsqueeze(1)) + self.U_z(vt_1))).squeeze(1)
        # [batch_size, 1] — reset gate.
        r = (self.sigmoid(self.W_r(h_e.unsqueeze(1)) + self.U_r(vt_1))).squeeze(1)
        # [batch_size, 1] — candidate state from gated previous vector.
        h = self.tanh(
            self.W(h_e.unsqueeze(1)) +
            self.U((torch.mul(r, vt_1.squeeze(1))).unsqueeze(1))).squeeze(1)
        # [batch_size, 1, embedding_size] — convex combination of old state and candidate.
        vt = (torch.mul((1 - z), vt_1.squeeze(1)) + torch.mul(z, h)).unsqueeze(1)
        return vt

    @overrides
    def forward(self,
                trigger_0: Dict[str, torch.LongTensor],
                trigger_agent_0: Dict[str, torch.LongTensor],
                agent_attri_0: Dict[str, torch.LongTensor],
                trigger_object_0: Dict[str, torch.LongTensor],
                object_attri_0: Dict[str, torch.LongTensor],
                trigger_1: Dict[str, torch.LongTensor],
                trigger_agent_1: Dict[str, torch.LongTensor],
                agent_attri_1: Dict[str, torch.LongTensor],
                trigger_object_1: Dict[str, torch.LongTensor],
                object_attri_1: Dict[str, torch.LongTensor],
                trigger_2: Dict[str, torch.LongTensor],
                trigger_agent_2: Dict[str, torch.LongTensor],
                agent_attri_2: Dict[str, torch.LongTensor],
                trigger_object_2: Dict[str, torch.LongTensor],
                object_attri_2: Dict[str, torch.LongTensor],
                trigger_3: Dict[str, torch.LongTensor],
                trigger_agent_3: Dict[str, torch.LongTensor],
                agent_attri_3: Dict[str, torch.LongTensor],
                trigger_object_3: Dict[str, torch.LongTensor],
                object_attri_3: Dict[str, torch.LongTensor],
                trigger_4: Dict[str, torch.LongTensor],
                trigger_agent_4: Dict[str, torch.LongTensor],
                agent_attri_4: Dict[str, torch.LongTensor],
                trigger_object_4: Dict[str, torch.LongTensor],
                object_attri_4: Dict[str, torch.LongTensor],
                event_type: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Embed the five events, run gated-attention hops starting from the last trigger's
        event vector, and score the result.  NOTE(review): ``event_type`` and the
        ``*_attri_*`` inputs are accepted but not used in this body — confirm whether
        they are consumed by ``EventEmbedding`` elsewhere or are dead parameters.
        """
        # tri, e: [batch_size, 1, embedding_size]
        tri0, e0 = self.event_embedding(trigger_0, trigger_agent_0, trigger_object_0)
        tri1, e1 = self.event_embedding(trigger_1, trigger_agent_1, trigger_object_1)
        tri2, e2 = self.event_embedding(trigger_2, trigger_agent_2, trigger_object_2)
        tri3, e3 = self.event_embedding(trigger_3, trigger_agent_3, trigger_object_3)
        tri4, e4 = self.event_embedding(trigger_4, trigger_agent_4, trigger_object_4)
        # [batch_size, seq_Len, embedding_size]
        e = (torch.stack([e0, e1, e2, e3, e4], dim=1)).squeeze(2)
        # [batch_size, 1, embedding_size] — start the memory from the final event.
        vt = tri4
        for i in range(self.args.hop_num):
            # [batch_size, 1, embedding_size]
            vt = self.gated_atten(vt, e)
        # [batch_size, embedding_size]
        x = vt.view(vt.size(0), -1)
        # [batch_size, 1] , [batch_size], [batch_size, label_size]
        score, logits, logits_f1 = self.score(x, tri4)
        output = {"logits": logits, "score": score}
        if label is not None:
            self.accuracy(logits, label)
            self.f1_score(logits_f1, label)
            output["loss"] = self.loss_function(score.squeeze(1), label.float())
        return output

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return accuracy plus precision/recall/F1 for the positive label."""
        accuracy = self.accuracy.get_metric(reset)
        precision, recall, f1_measure = self.f1_score.get_metric(reset)
        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_measure": f1_measure
        }
class BidirectionalAttentionFlow_1(Model): """ This class implements a Bayesian version of Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). """ def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None: super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer) """ Initialize some data structures """ self.cf_a = cf_a # Bayesian data models self.VBmodels = [] self.LinearModels = [] """ ############## TEXT FIELD EMBEDDER with ELMO #################### text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. """ if (cf_a.use_ELMO): if (type(preloaded_elmo) != type(None)): text_field_embedder = preloaded_elmo else: text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput ) print ("ELMO loaded from disk or downloaded") else: text_field_embedder = None # embedder_out_dim = text_field_embedder.get_output_dim() self._text_field_embedder = text_field_embedder if(cf_a.Add_Linear_projection_ELMO): if (self.cf_a.VB_Linear_projection_ELMO): prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior)) print ("----------------- Bayesian Linear Projection ELMO --------------") linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(), 200, prior = prior) self.VBmodels.append(linear_projection_ELMO) else: linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(), 200) self._linear_projection_ELMO = linear_projection_ELMO """ ############## Highway layers #################### num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. 
""" Input_dimension_highway = None if (cf_a.Add_Linear_projection_ELMO): Input_dimension_highway = 200 else: Input_dimension_highway = text_field_embedder.get_output_dim() num_highway_layers = cf_a.num_highway_layers # Linear later to compute the start if (self.cf_a.VB_highway_layers): print ("----------------- Bayesian Highway network --------------") prior = Vil.Prior(**(cf_a.VB_highway_layers_prior)) highway_layer = HighwayVB(Input_dimension_highway, num_highway_layers, prior = prior) self.VBmodels.append(highway_layer) else: highway_layer = Highway(Input_dimension_highway, num_highway_layers) highway_layer = TimeDistributed(highway_layer) self._highway_layer = highway_layer """ ############## Phrase layer #################### phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. """ if cf_a.phrase_layer_dropout > 0: ## Create dropout layer dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout) else: dropout_phrase_layer = lambda x: x phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway, hidden_size = cf_a.phrase_layer_hidden_size, batch_first=True, bidirectional = True, num_layers = cf_a.phrase_layer_num_layers, dropout = cf_a.phrase_layer_dropout)) phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2 self._phrase_layer = phrase_layer self._dropout_phrase_layer = dropout_phrase_layer """ ############## Matrix attention layer #################### similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. 
""" # Linear later to compute the start if (self.cf_a.VB_similarity_function): prior = Vil.Prior(**(cf_a.VB_similarity_function_prior)) print ("----------------- Bayesian Similarity matrix --------------") similarity_function = LinearSimilarityVB( combination = "x,y,x*y", tensor_1_dim = phrase_encoding_out_dim, tensor_2_dim = phrase_encoding_out_dim, prior = prior) self.VBmodels.append(similarity_function) else: similarity_function = LinearSimilarity( combination = "x,y,x*y", tensor_1_dim = phrase_encoding_out_dim, tensor_2_dim = phrase_encoding_out_dim) matrix_attention = LegacyMatrixAttention(similarity_function) self._matrix_attention = matrix_attention """ ############## Modelling Layer #################### modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. """ ## Create dropout layer if cf_a.modeling_passage_dropout > 0: ## Create dropout layer dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout) else: dropout_modeling_passage = lambda x: x modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4, hidden_size = cf_a.modeling_passage_hidden_size, batch_first=True, bidirectional = True, num_layers = cf_a.modeling_passage_num_layers, dropout = cf_a.modeling_passage_dropout)) self._modeling_layer = modeling_layer self._dropout_modeling_passage = dropout_modeling_passage """ ############## Span Start Representation ##################### span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. 
""" encoding_dim = phrase_layer.get_output_dim() modeling_dim = modeling_layer.get_output_dim() span_start_input_dim = encoding_dim * 4 + modeling_dim # Linear later to compute the start if (self.cf_a.VB_span_start_predictor_linear): prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior)) print ("----------------- Bayesian Span Start Predictor--------------") span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior = prior) self.VBmodels.append(span_start_predictor_linear) else: span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1) self._span_start_predictor_linear = span_start_predictor_linear self._span_start_predictor = TimeDistributed(span_start_predictor_linear) """ ############## Span End Representation ##################### """ ## Create dropout layer if cf_a.span_end_encoder_dropout > 0: dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout) else: dropout_span_end_encode = lambda x: x span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3, hidden_size = cf_a.modeling_span_end_hidden_size, batch_first=True, bidirectional = True, num_layers = cf_a.modeling_span_end_num_layers, dropout = cf_a.span_end_encoder_dropout)) span_end_encoding_dim = span_end_encoder.get_output_dim() span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim self._span_end_encoder = span_end_encoder self._dropout_span_end_encode = dropout_span_end_encode if (self.cf_a.VB_span_end_predictor_linear): print ("----------------- Bayesian Span End Predictor--------------") prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior)) span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior) self.VBmodels.append(span_end_predictor_linear) else: span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1) self._span_end_predictor_linear = span_end_predictor_linear self._span_end_predictor = TimeDistributed(span_end_predictor_linear) """ Dropput last layers """ if 
cf_a.spans_output_dropout > 0: dropout_spans_output = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout) else: dropout_spans_output = lambda x: x self._dropout_spans_output = dropout_spans_output """ Checkings and accuracy """ # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily # obvious from the configuration files, so we check here. check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim, "modeling layer input dim", "4 * encoding dim") check_dimensions_match(Input_dimension_highway , phrase_layer.get_input_dim(), "text field embedder output dim", "phrase layer input dim") check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim, "span end encoder input dim", "4 * encoding dim + 3 * modeling dim") self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() """ mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. 
""" self._mask_lstms = cf_a.mask_lstms """ ################### Initialize parameters ############################## """ #### THEY ARE ALL INITIALIZED WHEN INSTANTING THE COMPONENTS ### """ ####################### OPTIMIZER ################ """ optimizer = pytut.get_optimizers(self, cf_a) self._optimizer = optimizer #### TODO: Learning rate scheduler #### #scheduler = optim.ReduceLROnPlateau(optimizer, 'max') def forward_ensemble(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None, get_sample_level_information = False) -> Dict[str, torch.Tensor]: """ Sample 10 times and add them together """ self.set_posterior_mean(True) most_likely_output = self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information) self.set_posterior_mean(False) subresults = [most_likely_output] for i in range(10): subresults.append(self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information)) batch_size = len(subresults[0]["best_span"]) best_span = bidut.merge_span_probs(subresults) output = { "best_span": best_span, "best_span_str": [], "models_output": subresults } if (get_sample_level_information): output["em_samples"] = [] output["f1_samples"] = [] for index in range(batch_size): if metadata is not None: passage_str = metadata[index]['original_passage'] offsets = metadata[index]['token_offsets'] predicted_span = tuple(best_span[index].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output["best_span_str"].append(best_span_string) answer_texts = metadata[index].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) if (get_sample_level_information): em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts) 
output["em_samples"].append(em_sample) output["f1_samples"].append(f1_sample) if (get_sample_level_information): # Add information about the individual samples for future analysis output["span_start_sample_loss"] = [] output["span_end_sample_loss"] = [] for i in range (batch_size): span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults) span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults) span_start_loss = nll_loss(span_start_probs[[i],:], span_start.squeeze(-1)[[i]]) span_end_loss = nll_loss(span_end_probs[[i],:], span_end.squeeze(-1)[[i]]) output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy())) output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy())) return output def forward(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None, get_sample_level_information = False, get_attentions = False) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. 
If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key. Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span end position (inclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)`` and each offset is a token index. loss : torch.FloatTensor, optional A scalar loss to be optimised. best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. 
""" """ #################### Sample Bayesian weights ################## """ self.sample_posterior() """ ################## MASK COMPUTING ######################## """ question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None """ ###################### EMBEDDING + HIGHWAY LAYER ######################## """ # self.cf_a.use_ELMO if(self.cf_a.Add_Linear_projection_ELMO): embedded_question = self._highway_layer(self._linear_projection_ELMO (self._text_field_embedder(question['character_ids'])["elmo_representations"][-1])) embedded_passage = self._highway_layer(self._linear_projection_ELMO(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1])) else: embedded_question = self._highway_layer(self._text_field_embedder(question['character_ids'])["elmo_representations"][-1]) embedded_passage = self._highway_layer(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1]) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) """ ###################### phrase_layer LAYER ######################## """ encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) """ ###################### Attention LAYER ######################## """ # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, 
passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * tiled_question_passage_vector], dim=-1) modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask)) modeling_dim = modeled_passage.size(-1) """ ###################### Spans LAYER ######################## """ # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)) span_start_input = self._dropout_spans_output(torch.cat([final_merged_passage, modeled_passage], dim=-1)) # Shape: (batch_size, passage_length) span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1) # Shape: (batch_size, passage_length) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) # Shape: (batch_size, modeling_dim) span_start_representation = util.weighted_sum(modeled_passage, span_start_probs) # Shape: (batch_size, passage_length, modeling_dim) tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size, passage_length, modeling_dim) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3) 
span_end_representation = torch.cat([final_merged_passage, modeled_passage, tiled_start_representation, modeled_passage * tiled_start_representation], dim=-1) # Shape: (batch_size, passage_length, encoding_dim) encoded_span_end = self._dropout_span_end_encode(self._span_end_encoder(span_end_representation, passage_lstm_mask)) # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim) span_end_input = self._dropout_spans_output(torch.cat([final_merged_passage, encoded_span_end], dim=-1)) span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = bidut.get_best_span(span_start_logits, span_end_logits) output_dict = { "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span, } # Compute the loss for training. if span_start is not None: span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) loss = span_start_loss + span_end_loss self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) output_dict["loss"] = loss output_dict["span_start_loss"] = span_start_loss output_dict["span_end_loss"] = span_end_loss # Compute the EM and F1 on SQuAD and add the tokenized input to the output. 
if metadata is not None: if (get_sample_level_information): output_dict["em_samples"] = [] output_dict["f1_samples"] = [] output_dict['best_span_str'] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens']) passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output_dict['best_span_str'].append(best_span_string) answer_texts = metadata[i].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) if (get_sample_level_information): em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts) output_dict["em_samples"].append(em_sample) output_dict["f1_samples"].append(f1_sample) output_dict['question_tokens'] = question_tokens output_dict['passage_tokens'] = passage_tokens if (get_sample_level_information): # Add information about the individual samples for future analysis output_dict["span_start_sample_loss"] = [] output_dict["span_end_sample_loss"] = [] for i in range (batch_size): span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i],:], passage_mask[[i],:]), span_start.squeeze(-1)[[i]]) span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i],:], passage_mask[[i],:]), span_end.squeeze(-1)[[i]]) output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy())) output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy())) if(get_attentions): output_dict["C2Q_attention"] = passage_question_attention output_dict["Q2C_attention"] = question_passage_attention output_dict["simmilarity"] = passage_question_similarity return output_dict def get_metrics(self, reset: bool = False) -> 
Dict[str, float]: exact_match, f1_score = self._squad_metrics.get_metric(reset) return { 'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'em': exact_match, 'f1': f1_score, } def train_batch(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: """ It is enough to just compute the total loss because the normal weights do not depend on the KL Divergence """ # Now we can just compute both losses which will build the dynamic graph output = self.forward(question,passage,span_start,span_end,metadata ) data_loss = output["loss"] KL_div = self.get_KL_divergence() total_loss = self.combine_losses(data_loss, KL_div) self.zero_grad() # zeroes the gradient buffers of all parameters total_loss.backward() if (type(self._optimizer) == type(None)): parameters = filter(lambda p: p.requires_grad, self.parameters()) with torch.no_grad(): for f in parameters: f.data.sub_(f.grad.data * self.lr ) else: # print ("Training") self._optimizer.step() self._optimizer.zero_grad() return output def fill_batch_training_information(self, training_logger, output_batch): """ Function to fill the the training_logger for each batch. 
training_logger: Dictionary that will hold all the training info output_batch: Output from training the batch """ training_logger["train"]["span_start_loss_batch"].append(output_batch["span_start_loss"].detach().cpu().numpy()) training_logger["train"]["span_end_loss_batch"].append(output_batch["span_end_loss"].detach().cpu().numpy()) training_logger["train"]["loss_batch"].append(output_batch["loss"].detach().cpu().numpy()) # Training metrics: metrics = self.get_metrics() training_logger["train"]["start_acc_batch"].append(metrics["start_acc"]) training_logger["train"]["end_acc_batch"].append(metrics["end_acc"]) training_logger["train"]["span_acc_batch"].append(metrics["span_acc"]) training_logger["train"]["em_batch"].append(metrics["em"]) training_logger["train"]["f1_batch"].append(metrics["f1"]) def fill_epoch_training_information(self, training_logger,device, validation_iterable, num_batches_validation): """ Fill the information per each epoch """ Ntrials_CUDA = 100 # Training Epoch final metrics metrics = self.get_metrics(reset = True) training_logger["train"]["start_acc"].append(metrics["start_acc"]) training_logger["train"]["end_acc"].append(metrics["end_acc"]) training_logger["train"]["span_acc"].append(metrics["span_acc"]) training_logger["train"]["em"].append(metrics["em"]) training_logger["train"]["f1"].append(metrics["f1"]) self.set_posterior_mean(True) self.eval() data_loss_validation = 0 loss_validation = 0 with torch.no_grad(): # Compute the validation accuracy by using all the Validation dataset but in batches. 
for j in range(num_batches_validation): tensor_dict = next(validation_iterable) trial_index = 0 while (1): try: tensor_dict = pytut.move_to_device(tensor_dict, device) ## Move the tensor to cuda output_batch = self.forward(**tensor_dict) break; except RuntimeError as er: print (er.args) torch.cuda.empty_cache() time.sleep(5) torch.cuda.empty_cache() trial_index += 1 if (trial_index == Ntrials_CUDA): print ("Too many failed trials to allocate in memory") send_error_email(str(er.args)) sys.exit(0) data_loss_validation += output_batch["loss"].detach().cpu().numpy() ## Memmory management !! if (self.cf_a.force_free_batch_memory): del tensor_dict["question"]; del tensor_dict["passage"] del tensor_dict del output_batch torch.cuda.empty_cache() if (self.cf_a.force_call_garbage_collector): gc.collect() data_loss_validation = data_loss_validation/num_batches_validation # loss_validation = loss_validation/num_batches_validation # Training Epoch final metrics metrics = self.get_metrics(reset = True) training_logger["validation"]["start_acc"].append(metrics["start_acc"]) training_logger["validation"]["end_acc"].append(metrics["end_acc"]) training_logger["validation"]["span_acc"].append(metrics["span_acc"]) training_logger["validation"]["em"].append(metrics["em"]) training_logger["validation"]["f1"].append(metrics["f1"]) training_logger["validation"]["data_loss"].append(data_loss_validation) self.train() self.set_posterior_mean(False) def trim_model(self, mu_sigma_ratio = 2): total_size_w = [] total_removed_w = [] total_size_b = [] total_removed_b = [] if (self.cf_a.VB_Linear_projection_ELMO): VBmodel = self._linear_projection_ELMO size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_highway_layers): VBmodel = self._highway_layer._module.VBmodels[0] Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) 
size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_similarity_function): VBmodel = self._matrix_attention._similarity_function Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_span_start_predictor_linear): VBmodel = self._span_start_predictor_linear Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_span_end_predictor_linear): VBmodel = self._span_end_predictor_linear Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) return total_size_w, total_removed_w, total_size_b, total_removed_b # print (weights_to_remove_W.shape) """ BAYESIAN NECESSARY FUNCTIONS """ sample_posterior = GeneralVBModel.sample_posterior get_KL_divergence = GeneralVBModel.get_KL_divergence set_posterior_mean = GeneralVBModel.set_posterior_mean combine_losses = GeneralVBModel.combine_losses def save_VB_weights(self): """ Function that saves only the VB weights of the model. """ pretrained_dict = ... model_dict = self.state_dict() # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. 
load the new state dict self.load_state_dict(pretrained_dict)
def test_does_not_divide_by_zero_with_no_count(self, device: str):
    """A freshly constructed BooleanAccuracy with zero observations must report 0.0 rather than dividing by zero."""
    metric = BooleanAccuracy()
    self.assertAlmostEqual(metric.get_metric(), 0.0)
class BertSpanPointerResolution(Model):
    """Jointly predicts which query positions are "mask" positions and, for each
    such position, the start/end of the context span that should be inserted there
    (incomplete-utterance rewriting via span pointers)."""

    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str = None,
                 start_attention: Attention = None,
                 end_attention: Attention = None,
                 text_field_embedder: TextFieldEmbedder = None,
                 task_pretrained_file: str = None,
                 neg_sample_ratio: float = 0.0,
                 max_turn_len: int = 3,
                 start_token: str = "[CLS]",
                 end_token: str = "[SEP]",
                 index_name: str = "bert",
                 eps: float = 1e-8,
                 seed: int = 42,
                 loss_factor: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None):
        super().__init__(vocab, regularizer)
        # At least one way to build the encoder must be supplied.
        if model_name is None and text_field_embedder is None:
            raise ValueError(
                f"`model_name` and `text_field_embedder` can't both equal to None."
            )
        # For the plain resolution task we only need the last layer's embeddings
        # (no hidden-state stack), hence return_all=False / output_hidden_states=False.
        self._text_field_embedder = text_field_embedder or PretrainedChineseBertMismatchedEmbedder(
            model_name,
            return_all=False,
            output_hidden_states=False,
            max_turn_length=max_turn_len)

        seed_everything(seed)
        self._neg_sample_ratio = neg_sample_ratio
        self._start_token = start_token
        self._end_token = end_token
        self._index_name = index_name
        self._initializer = initializer

        linear_input_size = self._text_field_embedder.get_output_dim()
        # Attention-based pointers: a bilinear score between each query position's
        # vector and every context position's vector.
        self.start_attention = start_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        self.end_attention = end_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)

        # Mask metrics: F-score oriented — recall of label `1` matters most here.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()

        self._rewrite_em = RewriteEM(valid_keys="semr,nr_semr,re_semr")
        self._restore_score = RestorationScore(compute_restore_tokens=True)
        self._metrics = [
            TokenBasedBLEU(mode="1,2"),
            TokenBasedROUGE(mode="1r,2r")
        ]
        self._eps = eps
        self._loss_factor = loss_factor

        self._initializer(self.start_attention)
        self._initializer(self.end_attention)

        # Optionally warm-start from weights pretrained on a related task.
        if task_pretrained_file is not None and os.path.isfile(
                task_pretrained_file):
            logger.info("loading related task pretrained weights...")
            self.load_state_dict(torch.load(task_pretrained_file),
                                 strict=False)

    def _calc_loss(self, span_start_logits: torch.Tensor,
                   span_end_logits: torch.Tensor,
                   use_mask_label: torch.Tensor, start_label: torch.Tensor,
                   end_label: torch.Tensor, best_spans: torch.Tensor):
        """Cross-entropy loss over start/end pointers, restricted to the query
        positions selected by ``use_mask_label``; also feeds the span metrics."""
        batch_size = start_label.size(0)
        # ignore_index=-1: padded/unselected span labels do not contribute.
        loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-1)

        # --- gather the start/end labels at positions where mask_label == 1 ---
        # shape: [B_mask, ]
        span_start_label = start_label.masked_select(
            use_mask_label.to(dtype=torch.bool))
        span_end_label = end_label.masked_select(
            use_mask_label.to(dtype=torch.bool))

        # Mask out the (mostly -1) filler labels when computing accuracy.
        train_span_mask = (span_start_label != -1)

        # [B_mask, 2]
        answer_spans = torch.stack([span_start_label, span_end_label], dim=-1)
        self._span_accuracy(
            best_spans, answer_spans,
            train_span_mask.unsqueeze(-1).expand_as(best_spans))

        # -- start loss --
        start_losses = loss_fct(span_start_logits, span_start_label)
        # start_label_weight = self._calc_loss_weight(span_start_label)  # label weighting (disabled)
        # NOTE(review): normalized by batch_size, not by the number of selected
        # positions — presumably intentional; confirm against training recipe.
        start_loss = torch.sum(start_losses) / batch_size
        # Sanity check the loss magnitude before it can poison training.
        big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
        if torch.any(start_loss > big_constant):
            logger.critical("Start loss too high (%r)", start_loss)
            logger.critical("span_start_logits: %r", span_start_logits)
            logger.critical("span_start: %r", span_start_label)
            assert False

        # -- end loss --
        end_losses = loss_fct(span_end_logits, span_end_label)
        # end_label_weight = self._calc_loss_weight(span_end_label)  # label weighting (disabled)
        end_loss = torch.sum(end_losses) / batch_size
        if torch.any(end_loss > big_constant):
            logger.critical("End loss too high (%r)", end_loss)
            logger.critical("span_end_logits: %r", span_end_logits)
            logger.critical("span_end: %r", span_end_label)
            assert False

        span_loss = (start_loss + end_loss) / 2
        self._span_start_accuracy(span_start_logits, span_start_label,
                                  train_span_mask)
        self._span_end_accuracy(span_end_logits, span_end_label,
                                train_span_mask)

        loss = span_loss
        return loss

    def _calc_loss_weight(self, label: torch.Tensor):
        """Per-position loss weight: non-zero labels get weight (1 + loss_factor)."""
        label_mask = (label != 0).to(torch.float16)
        label_weight = label_mask * self._loss_factor + 1.0
        return label_weight

    def _get_rewrite_result(self, use_mask_label: torch.Tensor,
                            best_spans: torch.Tensor, query_lens: torch.Tensor,
                            context_lens: torch.Tensor,
                            metadata: List[Dict[str, Any]]):
        """Reconstruct the rewritten query string for every instance by inserting
        the predicted context spans after each predicted mask position, and feed
        the rewrite/restore metrics when gold tokens are available."""
        # Move both predictions to numpy types.
        # [B, query_len]
        use_mask_label = use_mask_label.detach().cpu().numpy()
        # [B_mask, 2] — flat list; consumed in order via pop(0) below, so the
        # order must match the order positions were selected in forward().
        best_spans = best_spans.detach().cpu().numpy().tolist()
        predict_rewrite_results = []
        for cur_query_len, cur_context_len, cur_query_mask_labels, mdata in zip(
                query_lens, context_lens, use_mask_label, metadata):
            context_tokens = mdata['context_tokens']
            query_tokens = mdata['query_tokens']
            cur_rewrite_result = copy.deepcopy(query_tokens)
            already_insert_tokens = 0  # number of tokens inserted so far
            # Smallest start / largest end inserted so far — used to suppress
            # spans nested inside an already-inserted region.
            already_insert_min_start = cur_context_len
            already_insert_max_end = 0
            # Walk the query's mask labels; label 1 means "insert a span here".
            for i in range(cur_query_len):
                cur_mask_label = cur_query_mask_labels[i]
                # Only positions predicted as 1 receive an insertion.
                if cur_mask_label:
                    predict_start, predict_end = best_spans.pop(0)
                    # (0, 0) means "nothing to insert".
                    if predict_start == 0 and predict_end == 0:
                        continue
                    # Start beyond the context: skip.
                    if predict_start >= cur_context_len:
                        continue
                    # Span already covered by a previous insertion: skip.
                    if predict_start >= already_insert_min_start and predict_end <= already_insert_max_end:
                        continue
                    # Clamp the boundaries into the valid token range.
                    if predict_start < 0 or context_tokens[
                            predict_start] == self._start_token:
                        predict_start = 1
                    if predict_end >= cur_context_len:
                        predict_end = cur_context_len - 1
                    # Predicted span is inclusive of predict_end.
                    predict_span_tokens = context_tokens[
                        predict_start:predict_end + 1]
                    # Track the covered region.
                    if predict_start < already_insert_min_start:
                        already_insert_min_start = predict_start
                    if predict_end > already_insert_max_end:
                        already_insert_max_end = predict_end

                    # Trim everything from the first end_token ([SEP]) onward.
                    try:
                        index = predict_span_tokens.index(self._end_token)
                        predict_span_tokens = predict_span_tokens[:index]
                    except BaseException:
                        pass

                    # Insertion point in the (already grown) rewrite: original
                    # position i shifted by the tokens inserted before it.
                    cur_insert_index = i + already_insert_tokens
                    cur_rewrite_result = cur_rewrite_result[:cur_insert_index] + \
                        predict_span_tokens + cur_rewrite_result[cur_insert_index:]

                    already_insert_tokens += len(predict_span_tokens)

            # Drop the trailing token (presumably the query's final [SEP] —
            # TODO confirm against the dataset reader).
            cur_rewrite_result = cur_rewrite_result[:-1]
            # Compare as strings rather than token lists.
            cur_rewrite_string = "".join(cur_rewrite_result)
            rewrite_tokens = mdata.get("rewrite_tokens", None)

            if rewrite_tokens is not None:
                rewrite_string = "".join(rewrite_tokens)
                # Strip the trailing [SEP]-like token from the query.
                query_string = "".join(query_tokens[:-1])
                self._rewrite_em(cur_rewrite_string, rewrite_string,
                                 query_string)
                # Additional BLEU/ROUGE style metrics.
                for metric in self._metrics:
                    metric(cur_rewrite_result, rewrite_tokens)
                # Restoration score over the restore tokens, when present.
                restore_tokens = mdata.get("restore_tokens", None)
                self._restore_score(cur_rewrite_result,
                                    rewrite_tokens,
                                    queries=query_tokens[:-1],
                                    restore_tokens=restore_tokens)

            predict_rewrite_results.append("".join(cur_rewrite_result))
        return predict_rewrite_results

    @overrides
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                context_lens: torch.Tensor,
                query_lens: torch.Tensor,
                mask_label: Optional[torch.Tensor] = None,
                start_label: Optional[torch.Tensor] = None,
                end_label: Optional[torch.Tensor] = None,
                metadata: List[Dict[str, Any]] = None):
        """Encode [context; query] with one pass of the embedder, score every
        (selected) query position against all context positions, and return the
        best span per position plus loss / rewrite strings when labels/metadata
        are given."""
        # Concatenate context and query before encoding; iterate the indexers first.
        indexers = context_ids.keys()
        dialogue_ids = {}

        # Max context/query lengths in this batch.
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()
        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # Merge each tensor of context and query along the length dim.
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # Encode the whole dialogue; the embedder may be an AllenNLP
        # TextFieldEmbedder or the custom mismatched BERT embedder.
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # Split the encoded sequence back into context and query halves.
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        # ------- span prediction -------
        # For every mask position in the query we want the start/end of the
        # context span to insert after it.
        # Expand context to [B, query_len, context_len, embed_size].
        context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
            -1, query_len, -1, -1).contiguous()
        # [B, query_len, context_len]
        context_expand_mask = context_mask.unsqueeze(dim=1).expand(
            -1, query_len, -1).contiguous()

        span_embed_size = context_last_layer.size(-1)

        if self.training and self._neg_sample_ratio > 0.0:
            # Negative sampling over positions whose mask label is 0.
            # [B*query_len, ]
            sample_mask_label = mask_label.view(-1)
            # Number of negatives to sample (at least 10).
            mask_length = sample_mask_label.size(0)
            mask_sum = int(
                torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
            mask_sum = max(10, mask_sum)
            # Random indices of the sampled negatives.
            neg_indexes = torch.randint(low=0,
                                        high=mask_length,
                                        size=(mask_sum, ))
            # Keep within range.
            neg_indexes = neg_indexes[:mask_length]
            # Mark the sampled negatives as selected (NOTE(review): this mutates
            # mask_label in place via the view — confirm that is intended).
            sample_mask_label[neg_indexes] = 1
            # [B, query_len]
            use_mask_label = sample_mask_label.view(
                -1, query_len).to(dtype=torch.bool)
            # Drop the padded part of the query, [B, query_len].
            use_mask_label = use_mask_label & query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # Context rows for the selected positions,
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # Query vectors for the selected positions.
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
        else:
            # Evaluation: every non-padded query position is a candidate.
            use_mask_label = query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # Context rows for the selected positions,
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # Query vectors for the selected positions.
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()

        # Pointer distributions over context positions, [B_mask, context_len].
        span_start_probs = self.start_attention(span_query_vector,
                                                span_context_matrix,
                                                span_context_mask)
        span_end_probs = self.end_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)

        # eps keeps log() finite for zero probabilities.
        span_start_logits = torch.log(span_start_probs + self._eps)
        span_end_logits = torch.log(span_end_probs + self._eps)

        # [B_mask, 2] — column 0 is start, column 1 is end.
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # Score of each best span = start logit + end logit.
        best_span_scores = (
            torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1))
            + torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
        # [B_mask, ]
        best_span_scores = best_span_scores.squeeze(1)

        # Collect everything downstream consumers need.
        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_spans": best_spans,
            "best_span_scores": best_span_scores
        }

        # With labels, compute the training loss.
        if start_label is not None:
            loss = self._calc_loss(span_start_logits, span_end_logits,
                                   use_mask_label, start_label, end_label,
                                   best_spans)
            output_dict["loss"] = loss
        # With metadata, materialize the rewritten strings.
        if metadata is not None:
            predict_rewrite_results = self._get_rewrite_result(
                use_mask_label, best_spans, query_lens, context_lens,
                metadata)
            output_dict['rewrite_results'] = predict_rewrite_results
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Aggregate span accuracy, BLEU/ROUGE, rewrite-EM and restoration scores."""
        metrics = {}
        metrics["span_acc"] = self._span_accuracy.get_metric(reset)
        for metric in self._metrics:
            metrics.update(metric.get_metric(reset))
        metrics.update(self._rewrite_em.get_metric(reset))
        metrics.update(self._restore_score.get_metric(reset))
        return metrics

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Expose only the rewritten strings to the end user."""
        new_output_dict = {}
        new_output_dict["rewrite_results"] = output_dict["rewrite_results"]
        return new_output_dict
class AttentionLSTM(Model):
    """Scores whether a candidate fifth event (event 4) follows a chain of four
    observed events: embeds each event, runs an LSTM over the first four,
    attends over the LSTM states, and compares the attended summary with the
    candidate via a Score layer (binary BCE objective)."""

    def __init__(self, args, word_embeddings: TextFieldEmbedder,
                 vocab: Vocabulary, domain_info: bool = True) -> None:
        super().__init__(vocab)
        # parameters
        self.args = args
        self.word_embeddings = word_embeddings
        self.domain = domain_info  # whether to mix event-type (domain) info into the LSTM states
        # layers
        self.event_embedding = EventEmbedding(args, self.word_embeddings)
        self.event_type_embedding = EventTypeEmbedding(args, self.word_embeddings)
        self.lstm = LSTM(input_size=self.args.embedding_size,
                         hidden_size=self.args.hidden_size)
        self.W_c = Linear(self.args.embedding_size, self.args.hidden_size, bias=False)
        self.W_e = Linear(self.args.hidden_size, self.args.hidden_size, bias=False)
        self.relu = ReLU()
        self.linear = Linear(self.args.hidden_size, self.args.embedding_size)
        self.attention = Attention(self.args.hidden_size, score_function='mlp')
        self.score = Score(self.args.embedding_size,
                           self.args.embedding_size,
                           threshold=self.args.threshold)
        # metrics
        self.accuracy = BooleanAccuracy()
        self.f1_score = F1Measure(positive_label=1)
        self.loss_function = BCELoss()

    @overrides
    def forward(self,
                trigger_0: Dict[str, torch.LongTensor],
                trigger_agent_0: Dict[str, torch.LongTensor],
                agent_attri_0: Dict[str, torch.LongTensor],
                trigger_object_0: Dict[str, torch.LongTensor],
                object_attri_0: Dict[str, torch.LongTensor],
                trigger_1: Dict[str, torch.LongTensor],
                trigger_agent_1: Dict[str, torch.LongTensor],
                agent_attri_1: Dict[str, torch.LongTensor],
                trigger_object_1: Dict[str, torch.LongTensor],
                object_attri_1: Dict[str, torch.LongTensor],
                trigger_2: Dict[str, torch.LongTensor],
                trigger_agent_2: Dict[str, torch.LongTensor],
                agent_attri_2: Dict[str, torch.LongTensor],
                trigger_object_2: Dict[str, torch.LongTensor],
                object_attri_2: Dict[str, torch.LongTensor],
                trigger_3: Dict[str, torch.LongTensor],
                trigger_agent_3: Dict[str, torch.LongTensor],
                agent_attri_3: Dict[str, torch.LongTensor],
                trigger_object_3: Dict[str, torch.LongTensor],
                object_attri_3: Dict[str, torch.LongTensor],
                trigger_4: Dict[str, torch.LongTensor],
                trigger_agent_4: Dict[str, torch.LongTensor],
                agent_attri_4: Dict[str, torch.LongTensor],
                trigger_object_4: Dict[str, torch.LongTensor],
                object_attri_4: Dict[str, torch.LongTensor],
                event_type: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # NOTE: the *_attri_* inputs are accepted but unused in this body.
        # Embed each event (trigger + agent + object); tri*, e*: [batch_size, 1, embedding_size]
        tri0, e0 = self.event_embedding(trigger_0, trigger_agent_0, trigger_object_0)
        tri1, e1 = self.event_embedding(trigger_1, trigger_agent_1, trigger_object_1)
        tri2, e2 = self.event_embedding(trigger_2, trigger_agent_2, trigger_object_2)
        tri3, e3 = self.event_embedding(trigger_3, trigger_agent_3, trigger_object_3)
        tri4, e4 = self.event_embedding(trigger_4, trigger_agent_4, trigger_object_4)
        event_type = self.event_type_embedding(event_type)
        # Stack the first four events: [batch_size, seq_len, embedding_size]
        e = (torch.stack([e0, e1, e2, e3], dim=1)).squeeze(2)
        batch_size, seq_len, _ = e.size()
        # Repeat the event type for each timestep: [batch_size, seq_len, embedding_size]
        event_types = (torch.stack(
            [event_type, event_type, event_type, event_type],
            dim=1)).squeeze(2)
        # Rearrange to [seq_len, batch_size, embedding_size] for the LSTM.
        # NOTE(review): this uses view(), which reinterprets memory rather than
        # transposing batch and time axes (that would be e.transpose(0, 1)) —
        # confirm this reshuffling is intended.
        e = e.view(seq_len, batch_size, -1)
        lstm_out, (hn, _) = self.lstm(e)
        # Back to [batch_size, seq_len, hidden_size] (same view() caveat as above).
        lstm_out = lstm_out.view(batch_size, seq_len, -1)
        if self.domain:
            # Residual mix-in of event-type (domain) information.
            lstm_out = lstm_out + self.relu(
                self.W_c(event_types) + self.W_e(lstm_out))
        # Final hidden state as the attention query: [batch_size, 1, hidden_size]
        hn = hn.view(batch_size, 1, -1)
        # Attend over the LSTM states: [batch_size, 1, hidden_size]
        out_atten, _ = self.attention(lstm_out, hn)
        # Project back to embedding space: [batch_size, 1, embedding_size]
        out_atten = self.linear(out_atten)
        # score: [batch_size, 1]; logits: [batch_size]; logits_f1: [batch_size, label_size]
        score, logits, logits_f1 = self.score(out_atten, e4)
        output = {"logits": logits, "score": score}
        if label is not None:
            self.accuracy(logits, label)
            self.f1_score(logits_f1, label)
            output["loss"] = self.loss_function(score.squeeze(1), label.float())
        return output

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Report accuracy plus precision/recall/F1 of the positive class."""
        accuracy = self.accuracy.get_metric(reset)
        precision, recall, f1_measure = self.f1_score.get_metric(reset)
        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_measure": f1_measure
        }
class BERT_QA(Model):
    """
    Extractive QA on top of a BERT-style embedder: a single linear layer
    (``qa_outputs``) maps each contextual token embedding to two logits,
    one for being the answer-span start and one for the span end.

    Tracks start/end ``CategoricalAccuracy``, whole-span ``BooleanAccuracy``
    and SQuAD-style EM/F1 (``SquadEmAndF1``).

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the concatenated question+passage (``context``) input.
    dropout : ``float``, optional (default=0.0)
        If greater than 0, dropout applied via ``self._dropout``; otherwise
        ``self._dropout`` is the identity so ``forward`` can call it blindly.
    max_span_length : ``int``, optional (default=30)
        Stored on the instance; not referenced in the code visible here.
    initializer : ``InitializerApplicator``
    regularizer : ``Optional[RegularizerApplicator]``
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 dropout: float = 0.0,
                 max_span_length: int = 30,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._max_span_length = max_span_length
        # Projects each contextual embedding to 2 logits: (start, end).
        self.qa_outputs = torch.nn.Linear(
            self._text_field_embedder.get_output_dim(), 2)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._span_qa_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in: lets the rest of the code call self._dropout
            # unconditionally.
            self._dropout = lambda x: x
        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            context: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Predict start/end logits for every token of ``context``, decode the
        best span, and — when gold ``span_start``/``span_end`` are given —
        accumulate the NLL loss and the accuracy/EM/F1 metrics.

        Returns a dict with start/end logits and probabilities, the decoded
        ``best_span`` (token indices), optionally ``loss``, and (when
        ``metadata`` is present) the recovered ``best_span_str`` answers.
        """
        # the `context` is the concact of `question` and `passage`, so we just use `context`
        batch_size, num_of_passage_tokens = context['tokens'].size()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
        # start and end spans.
        embedded_passage = self._text_field_embedder(context)
        passage_length = embedded_passage.size(1)
        logits = self.qa_outputs(embedded_passage)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Adding some masks with numerically stable values
        # NOTE(review): the mask comes from `passage` while the logits cover
        # `context` (question + passage); the .view() below assumes both have
        # the same token length — confirm against the dataset reader.
        passage_mask = util.get_text_field_mask(passage).float()
        # NOTE(review): .repeat(1, 1, 1) is a no-op on a (batch, 1, len) tensor.
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, 1, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            batch_size, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_start_probs = util.masked_softmax(span_start_logits,
                                               repeated_passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)
        span_end_probs = util.masked_softmax(span_end_logits,
                                             repeated_passage_mask)
        best_span = self.get_best_span(span_start_logits, span_end_logits)
        output_dict: Dict[str, Any] = {}
        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # compute the loss for training.
        # NOTE(review): the loss masks with `passage_mask` while the logits
        # above were masked with `repeated_passage_mask`; both are views of the
        # same values, so this is consistent only if the .view() above holds.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            # span_start/span_end are (batch, 1); cat makes (batch, 2) to match best_span.
            self._span_accuracy(best_span,
                                torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on span qa and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_words = metadata[i]["paragraph_words"]
                # answer_offset shifts wordpiece indices back into the passage's
                # own token space before the wordpiece->word lookup.
                answer_offset = metadata[i]["answer_offset"]
                tok_to_word_index = metadata[i]["tok_to_word_index"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_position = tok_to_word_index[predicted_span[0] -
                                                   answer_offset]
                end_position = tok_to_word_index[predicted_span[1] -
                                                 answer_offset]
                best_span_str = " ".join(
                    passage_words[start_position:end_position + 1])
                output_dict["best_span_str"].append(best_span_str)
                answer_text = metadata[i].get("answer_text", [])
                if answer_text:
                    # SquadEmAndF1 expects a list of gold answer strings.
                    answer_text = [answer_text]
                    self._span_qa_metrics(best_span_str, answer_text)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span accuracies plus SQuAD-style EM and F1."""
        exact_match, f1_score = self._span_qa_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Decode the highest-scoring (start, end) span with end >= start.

        Builds the full (passage_length x passage_length) matrix of
        start+end scores, masks the lower triangle (spans ending before they
        start) by adding log(0) = -inf, flattens, and takes a single argmax.
        Returns a (batch_size, 2) tensor of [start_index, end_index].
        """
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = (torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0))
        valid_span_log_probs = span_log_probs + span_log_mask
        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        # Flat index = start * passage_length + end; recover both parts.
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
class ModelSQUAD(Model):
    """
    BiDAF-style SQuAD model with a residual self-attention layer and an
    explicit no-answer score (cf. Clark & Gardner 2017).  The passage field
    carries 4 passages per question (see the hard-coded expand(4, ...) in
    ``forward``), and the final loss is an NLL over the flattened
    start*length+end span index augmented with one extra "no answer" class.

    Written against an older PyTorch/AllenNLP API (``Variable``,
    ``.resize``, ``.data[0]``, ``util.last_dim_softmax``) — keep the exact
    statement order when touching this code.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 feed_forward: FeedForward,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(ModelSQUAD, self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._residual_encoder = residual_encoder
        self._span_end_encoder = span_end_encoder
        self._span_start_encoder = span_start_encoder
        self._feed_forward = feed_forward
        encoding_dim = phrase_layer.get_output_dim()
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))
        # NOTE(review): span_end_encoding_dim is computed but the end predictor
        # below uses encoding_dim instead — confirm the two dims are equal.
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))
        self._no_answer_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))
        self._self_matrix_attention = MatrixAttention(
            attention_similarity_function)
        self._linear_layer = TimeDistributed(
            torch.nn.Linear(4 * encoding_dim, encoding_dim))
        self._residual_linear_layer = TimeDistributed(
            torch.nn.Linear(3 * encoding_dim, encoding_dim))
        # Per-dimension weights for the trilinear self-attention:
        # sim(x, y) = x.w_x + y.w_y + (x*w_xy).y
        self._w_x = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_y = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_xy = torch.nn.Parameter(torch.Tensor(encoding_dim))
        # Xavier-style uniform init for the three weight vectors.
        std = math.sqrt(6 / (encoding_dim * 3 + 1))
        self._w_x.data.uniform_(-std, std)
        self._w_y.data.uniform_(-std, std)
        self._w_xy.data.uniform_(-std, std)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so forward() can call self._dropout blindly.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms
        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.LongTensor = None,
            span_end: torch.LongTensor = None,
            spans=None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Full BiDAF + self-attention forward pass.

        The passage tensor is (batch, 4, passage_length, ...) and is flattened
        to (batch*4, ...) for encoding; the question is tiled 4x to match.
        Produces start/end/no-answer logits, decodes the best span, and builds
        an NLL loss over the flattened span index with an appended no-answer
        class scored by ``z_score``.
        """
        # pylint: disable=arguments-differ
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        # Shape: (batch_size, 4, passage_length, embedding_dim)
        embedded_passage = self._text_field_embedder(passage)
        (batch_size, q_length, embedding_dim) = embedded_question.size()
        passage_length = embedded_passage.size(2)
        # reshape: (batch_size*4, -1, embedding_dim)
        embedded_passage = embedded_passage.view(-1, passage_length,
                                                 embedding_dim)
        embedded_passage = self._highway_layer(embedded_passage)
        # Tile the question 4x so each of the 4 passages gets its own copy.
        embedded_question = embedded_question.unsqueeze(0).expand(
            4, -1, -1, -1).contiguous().view(-1, q_length, embedding_dim)
        question_mask = util.get_text_field_mask(question).float()
        question_mask = question_mask.unsqueeze(0).expand(
            4, -1, -1).contiguous().view(-1, q_length)
        passage_mask = util.get_text_field_mask(passage, 1).float()
        passage_mask = passage_mask.view(-1, passage_length)
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None
        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)
        cuda_device = encoded_question.get_device()
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)
        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)
        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        question_attended_passage = relu(
            self._linear_layer(final_merged_passage))
        # TODO: attach residual self-attention layer
        # Shape: (batch_size, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage),
                                   passage_lstm_mask))
        # Pairwise validity mask for self-attention; the eye() term zeroes the
        # diagonal so a token cannot attend to itself.
        mask = passage_mask.resize(batch_size, passage_length,
                                   1) * passage_mask.resize(
                                       batch_size, 1, passage_length)
        self_mask = Variable(
            torch.eye(passage_length,
                      passage_length).cuda(cuda_device)).resize(
                          1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        # Trilinear self-similarity: per-token x-term, y-term and scaled dot term.
        # Shape: (batch_size, passage_length, passage_length)
        x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
        y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
        dot_similarity = torch.bmm(residual_passage * self._w_xy,
                                   residual_passage.transpose(1, 2))
        passage_self_similarity = dot_similarity + x_similarity + y_similarity
        #for i in range(passage_length):
        #    passage_self_similarity[:, i, i] = float('-Inf')
        # Shape: (batch_size, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                       mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage,
                                            passage_self_attention)
        # Shape: (batch_size, passage_length, encoding_dim * 3)
        merged_passage = torch.cat([
            residual_passage, passage_vectors,
            residual_passage * passage_vectors
        ],
                                   dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        self_attended_passage = relu(
            self._residual_linear_layer(merged_passage))
        # Residual connection around the self-attention block.
        # Shape: (batch_size, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_start = self._dropout(
            self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(
            encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        # Shape: (batch_size, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
            -1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Summary vectors feeding the no-answer scorer.
        # Shape: (batch_size, encoding_dim)
        v_1 = util.weighted_sum(encoded_span_start, span_start_probs)
        v_2 = util.weighted_sum(encoded_span_end, span_end_probs)
        no_span_logits = self._no_answer_predictor(
            self_attended_passage).squeeze(-1)
        no_span_probs = util.masked_softmax(no_span_logits, passage_mask)
        v_3 = util.weighted_sum(self_attended_passage, no_span_probs)
        # Shape: (batch_size, 1)
        z_score = self._feed_forward(torch.cat([v_1, v_2, v_3], dim=-1))
        # compute no-answer score
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)
        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }
        # create target tensor including no-answer label
        # Flat target index = start * passage_length + end; negative gold spans
        # map to the extra no-answer class at index passage_length**2.
        span_target = Variable(torch.ones(batch_size).long()).cuda(cuda_device)
        for b in range(batch_size):
            span_target[b].data[0] = span_start[
                b, 0].data[0] * passage_length + span_end[b, 0].data[0]
        span_target[span_target < 0] = passage_length**2
        # NOTE(review): with this tiling, flat index start*L+end selects
        # start_logits[end] + end_logits[start]; the unsqueeze dims look
        # transposed relative to span_target above — confirm.
        # Shape: (batch_size, passage_length, passage_length)
        span_start_logits_tiled = span_start_logits.unsqueeze(1).expand(
            batch_size, passage_length, passage_length)
        span_end_logits_tiled = span_end_logits.unsqueeze(-1).expand(
            batch_size, passage_length, passage_length)
        span_logits = (span_start_logits_tiled + span_end_logits_tiled).view(
            batch_size, -1)
        answer_mask = torch.bmm(passage_mask.unsqueeze(-1),
                                passage_mask.unsqueeze(1)).view(
                                    batch_size, -1)
        # The no-answer class is always a valid target.
        no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        all_logits = torch.cat([span_logits, z_score], dim=-1)
        loss = nll_loss(util.masked_log_softmax(all_logits, combined_mask),
                        span_target)
        output_dict["loss"] = loss
        # Shape(batch_size, max_answers, num_span)
        # max_answers = spans.size(1)
        # span_logits = torch.bmm(span_start_logits.unsqueeze(-1), span_end_logits.unsqueeze(1)).view(batch_size, -1)
        # answer_mask = torch.bmm(passage_mask.unsqueeze(-1), passage_mask.unsqueeze(1)).view(batch_size, -1)
        # no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        # combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        # # Shape: (batch_size, passage_length**2 + 1)
        # all_logits = torch.cat([span_logits, z_score], dim=-1)
        # # Shape: (batch_size, max_answers)
        # spans_combined = spans[:, :, 0] * passage_length + spans[:, :, 1]
        # spans_combined[spans_combined < 0] = passage_length*passage_length
        #
        # all_modified_logits = []
        # for b in range(batch_size):
        #     idxs = Variable(torch.LongTensor(range(passage_length**2 + 1))).cuda(cuda_device)
        #     for i in range(max_answers):
        #         idxs[spans_combined[b, i].data[0]].data = idxs[spans_combined[b, 0].data[0]].data
        #     idxs[passage_length**2].data[0] = passage_length**2
        #     modified_logits = Variable(torch.zeros(all_logits.size(-1))).cuda(cuda_device)
        #     modified_logits.index_add_(0, idxs, all_logits[b])
        #     all_modified_logits.append(modified_logits)
        # all_modified_logits = torch.stack(all_modified_logits, dim=0)
        # loss = nll_loss(util.masked_log_softmax(all_modified_logits, combined_mask), spans_combined[:, 0])
        # output_dict["loss"] = loss
        if span_start is not None:
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                # Map token-level span back to character offsets in the raw text.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span accuracies plus SQuAD EM and F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: Variable,
                      span_end_logits: Variable) -> Variable:
        """
        Greedy O(n) span decoding on CPU: for each end position j, pair it
        with the best start seen so far (<= j) and keep the running maximum.
        Returns a (batch_size, 2) long tensor of [start, end] indices.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new().resize_(
            batch_size, 2).fill_(0)).long()
        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]
                val2 = span_end_logits[b, j]
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ModelSQUAD':
        """Build the model from an AllenNLP ``Params`` configuration blob."""
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)
        num_highway_layers = params.pop_int("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(
            params.pop("similarity_function"))
        residual_encoder = Seq2SeqEncoder.from_params(
            params.pop("residual_encoder"))
        span_start_encoder = Seq2SeqEncoder.from_params(
            params.pop("span_start_encoder"))
        span_end_encoder = Seq2SeqEncoder.from_params(
            params.pop("span_end_encoder"))
        feed_forward = FeedForward.from_params(params.pop("feed_forward"))
        dropout = params.pop_float('dropout', 0.2)
        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))
        mask_lstms = params.pop_bool('mask_lstms', True)
        # Fail loudly on unconsumed config keys.
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   residual_encoder=residual_encoder,
                   span_start_encoder=span_start_encoder,
                   span_end_encoder=span_end_encoder,
                   feed_forward=feed_forward,
                   dropout=dropout,
                   mask_lstms=mask_lstms,
                   initializer=initializer,
                   regularizer=regularizer)
class CMVDiscriminator(FeedForward):
    """
    GAN-style discriminator built as a highway-gated ``FeedForward``.

    Each hidden layer (after the first) is combined with its input through a
    sigmoid gate: out = g * layer(x) + (1 - g) * x, where the gate bias is
    initialised negative (``gate_bias``) so gates start mostly closed and the
    network initially passes inputs through.  A final 1-unit feed-forward
    scores real vs. fake inputs with a BCE-with-logits loss.

    Relies on the parent ``FeedForward``'s private ``_linear_layers``,
    ``_activations`` and ``_dropout`` lists, so it must stay in lock-step
    with that implementation.
    """

    def __init__(self,
                 input_dim: int,
                 num_layers: int,
                 hidden_dims: Union[int, Sequence[int]],
                 activations: Union[Activation, Sequence[Activation]],
                 dropout: Union[float, Sequence[float]] = 0.0,
                 gate_bias: float = -2) -> None:
        super(CMVDiscriminator, self).__init__(input_dim, num_layers,
                                               hidden_dims, activations,
                                               dropout)

        # NOTE(review): the parent expands a scalar hidden_dims to num_layers
        # copies; here it is expanded to num_layers - 1 — confirm intended.
        if not isinstance(hidden_dims, list):
            hidden_dims = [hidden_dims] * (num_layers - 1)
        # Pair each gated layer's input dim (hidden_dims[i+1]) with its output
        # dim (hidden_dims[i]); with uniform dims these coincide.
        input_dims = hidden_dims[1:]
        gate_layers = [None]  #so we can zip this later
        for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims):
            gate_layer = torch.nn.Linear(layer_input_dim, layer_output_dim)
            # Negative bias => sigmoid(gate) starts near 0 (carry behaviour).
            gate_layer.bias.data.fill_(gate_bias)
            gate_layers.append(gate_layer)

        # The leading None keeps index alignment: layer 0 has no gate.
        self._gate_layers = torch.nn.ModuleList(gate_layers)

        #feedforward requires an Activation so we just use the identity
        self._output_feedforward = FeedForward(hidden_dims[-1], 1, 1,
                                               lambda x: x)

        self._accuracy = BooleanAccuracy()

    def _get_hidden(self, output):
        """Run the highway-gated stack; first layer is ungated."""
        layers = list(
            zip(self._linear_layers, self._activations, self._dropout,
                self._gate_layers))
        layer, activation, dropout, _ = layers[0]
        output = dropout(activation(layer(output)))

        for layer, activation, dropout, gate in layers[1:]:
            gate_output = torch.sigmoid(gate(output))
            new_output = dropout(activation(layer(output)))
            # Highway combination: gate chooses between transform and carry.
            output = torch.add(torch.mul(gate_output, new_output),
                               torch.mul(1 - gate_output, output))

        return output

    def forward(self, real_output, fake_output=None):
        """
        Score real (label 1) and optionally fake (label 0) inputs.

        Returns a dict with the summed BCE ``loss``, boolean ``predictions``
        (sigmoid > 0.5) and the ``labels`` used; also updates the internal
        BooleanAccuracy metric.
        """
        real_hidden = self._get_hidden(real_output)
        real_value = self._output_feedforward(real_hidden)

        labels = torch.ones(real_hidden.size(0))
        if torch.cuda.is_available() and real_value.is_cuda:
            idx = real_value.get_device()
            labels = labels.cuda(idx)

        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            real_value.view(-1), labels)

        predictions = torch.sigmoid(real_value) > 0.5

        if fake_output is not None:
            fake_hidden = self._get_hidden(fake_output)
            fake_value = self._output_feedforward(fake_hidden)

            fake_labels = torch.zeros(fake_hidden.size(0))
            if torch.cuda.is_available() and fake_value.is_cuda:
                idx = fake_value.get_device()
                fake_labels = fake_labels.cuda(idx)

            # Accumulate the fake-side loss on top of the real-side loss.
            loss += torch.nn.functional.binary_cross_entropy_with_logits(
                fake_value.view(-1), fake_labels)

            predictions = torch.cat(
                [predictions, torch.sigmoid(fake_value) > 0.5])
            labels = torch.cat([labels, fake_labels])

        self._accuracy(predictions, labels.byte())

        return {'loss': loss, 'predictions': predictions, 'labels': labels}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return discriminator accuracy over real+fake predictions."""
        return {'accuracy': self._accuracy.get_metric(reset)}
class Seq2SeqTask(SequenceGenerationTask):
    """Sequence-to-sequence Task.

    Streams (source, target) sentence pairs from tab-separated ``train.tsv`` /
    ``val.tsv`` / ``test.tsv`` files under ``path``.  Tracks exact-token
    accuracy (``BooleanAccuracy``) as ``scorer2`` alongside the parent's
    perplexity scorer.
    """

    def __init__(self, path, max_seq_len, max_targ_v_size, name, **kw):
        """
        Parameters
        ----------
        path : str
            Directory containing ``train.tsv``, ``val.tsv`` and ``test.tsv``.
        max_seq_len : int
            Maximum (truncated) length of source and target sequences.
        max_targ_v_size : int
            Maximum target vocabulary size (most frequent tokens kept).
        name : str
            Task name; also prefixes the validation metric and label namespace.
        """
        super().__init__(name, **kw)
        self.scorer2 = BooleanAccuracy()
        self.scorers.append(self.scorer2)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.max_seq_len = max_seq_len
        self._label_namespace = self.name + "_tokens"
        self.max_targ_v_size = max_targ_v_size
        self.target_indexer = {"words": SingleIdTokenIndexer(namespace=self._label_namespace)}
        self.files_by_split = {
            split: os.path.join(path, "%s.tsv" % split) for split in ["train", "val", "test"]
        }

        # The following is necessary since word-level tasks (e.g., MT) haven't been tested, yet.
        if self._tokenizer_name != "SplitChars" and self._tokenizer_name != "dummy_tokenizer_name":
            raise NotImplementedError("For now, Seq2SeqTask only supports character-level tasks.")

    def load_data(self):
        # Data is exposed as iterable: no preloading
        pass

    def get_split_text(self, split: str):
        """
        Get split text as iterable of records.

        Split should be one of 'train', 'val', or 'test'.
        """
        return self.get_data_iter(self.files_by_split[split])

    def get_all_labels(self) -> List[str]:
        """Build character vocabulary and return it as a list.

        Counts target-side tokens over train+val and keeps the
        ``max_targ_v_size`` most frequent ones.
        """
        token2freq = collections.Counter()
        for split in ["train", "val"]:
            for _, sequence in self.get_data_iter(self.files_by_split[split]):
                token2freq.update(sequence)
        return [t for t, _ in token2freq.most_common(self.max_targ_v_size)]

    def get_data_iter(self, path):
        """Yield (source, target) token sequences from a two-column TSV file.

        Rows with fewer than two columns, or with an empty source or target,
        are skipped.
        """
        with codecs.open(path, "r", "utf-8", errors="ignore") as txt_fh:
            for row in txt_fh:
                row = row.strip().split("\t")
                if len(row) < 2 or not row[0] or not row[1]:
                    continue
                src_sent = tokenize_and_truncate(self._tokenizer_name, row[0], self.max_seq_len)
                # BUG FIX: the target is column 1 (validated above); reading
                # row[2] raised IndexError on well-formed two-column rows.
                tgt_sent = tokenize_and_truncate(self._tokenizer_name, row[1], self.max_seq_len)
                yield (src_sent, tgt_sent)

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """Yield sentences, used to compute vocabulary."""
        for split in self.files_by_split:
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            path = self.files_by_split[split]
            yield from self.get_data_iter(path)

    def count_examples(self):
        """Compute here b/c we're streaming the sentences.

        NOTE: counts raw file lines, so malformed rows that get_data_iter
        skips are still counted (preserved behavior).
        """
        example_counts = {}
        for split, split_path in self.files_by_split.items():
            example_counts[split] = sum(
                1 for _ in codecs.open(split_path, "r", "utf-8", errors="ignore")
            )
        self.example_counts = example_counts

    def process_split(
        self, split, indexers, model_preprocessing_interface
    ) -> Iterable[Instance]:
        """Process split text into a list of AllenNLP Instances."""

        def _make_instance(input_, target):
            d = {
                "inputs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(input_), indexers
                ),
                "targs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(target), self.target_indexer
                ),
            }
            return Instance(d)

        for sent1, sent2 in split:
            yield _make_instance(sent1, sent2)

    def get_metrics(self, reset=False):
        """Get metrics specific to the task: perplexity and token accuracy."""
        avg_nll = self.scorer1.get_metric(reset)
        acc = self.scorer2.get_metric(reset)
        return {"perplexity": math.exp(avg_nll), "accuracy": acc}

    def update_metrics(self, logits, labels, tagmask=None):
        """Update token-accuracy with the argmax over the vocab dimension."""
        self.scorer2(logits.max(2)[1], labels, tagmask)
        return

    def get_prediction(self, voc_src, voc_trg, inputs, gold, output):
        """Detokenize input/gold/predicted token-id sequences up to <EOS>."""
        tokenizer = get_tokenizer(self._tokenizer_name)

        def _to_string(vocab, token_ids):
            # Map ids to surface tokens, detokenize, and cut at the first EOS.
            return tokenizer.detokenize([vocab[token.item()] for token in token_ids]).split(
                "<EOS>"
            )[0]

        return _to_string(voc_src, inputs), _to_string(voc_trg, gold), _to_string(voc_trg, output)
class BertMCQAModel(Model):
    """
    Multiple-choice QA over BERT where each encoded row is a *pair* of
    answer choices ("comparative" BERT).  A classifier scores each pair
    (does choice 1 beat choice 2?); optionally a trained comparison layer
    combines all pairwise scores into a per-choice distribution.

    NOTE(review): the comparative path assumes batch_size == 1 (asserted in
    ``forward``) because of the last-token pooling trick below.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 top_layer_only: bool = True,
                 bert_weights_model: str = None,
                 per_choice_loss: bool = False,
                 layer_freeze_regexes: List[str] = None,
                 regularizer: Optional[RegularizerApplicator] = None,
                 use_comparative_bert: bool = True,
                 use_bilinear_classifier: bool = False,
                 train_comparison_layer: bool = False,
                 number_of_choices_compared: int = 0,
                 comparison_layer_hidden_size: int = -1,
                 comparison_layer_use_relu: bool = True) -> None:
        super().__init__(vocab, regularizer)

        self._use_comparative_bert = use_comparative_bert
        self._use_bilinear_classifier = use_bilinear_classifier
        self._train_comparison_layer = train_comparison_layer
        if train_comparison_layer:
            assert number_of_choices_compared > 1
            self._num_choices = number_of_choices_compared
            self._comparison_layer_hidden_size = comparison_layer_hidden_size
            self._comparison_layer_use_relu = comparison_layer_use_relu

        # Bert weights and config
        if bert_weights_model:
            logging.info(f"Loading BERT weights model from {bert_weights_model}")
            bert_model_loaded = load_archive(bert_weights_model)
            self._bert_model = bert_model_loaded.model._bert_model
        else:
            self._bert_model = BertModel.from_pretrained(pretrained_model)

        for param in self._bert_model.parameters():
            param.requires_grad = requires_grad
        #for name, param in self._bert_model.named_parameters():
        #    grad = requires_grad
        #    if layer_freeze_regexes and grad:
        #        grad = not any([bool(re.search(r, name)) for r in layer_freeze_regexes])
        #    param.requires_grad = grad

        bert_config = self._bert_model.config
        self._output_dim = bert_config.hidden_size
        self._dropout = torch.nn.Dropout(bert_config.hidden_dropout_prob)
        # Stored but not referenced in the code visible here.
        self._per_choice_loss = per_choice_loss

        # Bert Classifier selector
        final_output_dim = 1
        if not use_comparative_bert:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_classifier"):
                self._classifier = bert_model_loaded.model._classifier
            else:
                self._classifier = Linear(self._output_dim, final_output_dim)
        else:
            if use_bilinear_classifier:
                self._classifier = Bilinear(self._output_dim, self._output_dim, final_output_dim)
            else:
                # Comparative path concatenates the [CLS]-pooled vector with the
                # last-token pooled vector, hence 2x the hidden size.
                self._classifier = Linear(self._output_dim * 2, final_output_dim)
            self._classifier.apply(self._bert_model.init_bert_weights)

        # Comparison layer setup
        if self._train_comparison_layer:
            # Ordered pairs of distinct choices.
            number_of_pairs = self._num_choices * (self._num_choices - 1)
            if self._comparison_layer_hidden_size == -1:
                self._comparison_layer_hidden_size = number_of_pairs * number_of_pairs

            self._comparison_layer_1 = Linear(number_of_pairs, self._comparison_layer_hidden_size)
            if self._comparison_layer_use_relu:
                self._comparison_layer_1_activation = torch.nn.LeakyReLU()
            else:
                self._comparison_layer_1_activation = torch.nn.Tanh()
            self._comparison_layer_2 = Linear(self._comparison_layer_hidden_size, self._num_choices)
            self._comparison_layer_2_activation = torch.nn.Softmax()

        # Scalar mix, if necessary
        self._all_layers = not top_layer_only
        if self._all_layers:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_scalar_mix") \
                    and bert_model_loaded.model._scalar_mix is not None:
                self._scalar_mix = bert_model_loaded.model._scalar_mix
            else:
                num_layers = bert_config.num_hidden_layers
                initial_scalar_parameters = num_layers * [0.0]
                initial_scalar_parameters[-1] = 5.0  # Starts with most mass on last layer
                self._scalar_mix = ScalarMix(bert_config.num_hidden_layers,
                                             initial_scalar_parameters=initial_scalar_parameters,
                                             do_layer_norm=False)
        else:
            self._scalar_mix = None

        # Accuracy and loss setup
        if self._train_comparison_layer:
            self._accuracy = CategoricalAccuracy()
            self._loss = torch.nn.CrossEntropyLoss()
        else:
            # Pairwise binary decision: BooleanAccuracy + BCE on the pair probs.
            self._accuracy = BooleanAccuracy()
            self._loss = torch.nn.BCEWithLogitsLoss()
        self._debug = -1

    def _extract_last_token_pooled_output(self, encoded_layers, question_mask):
        """
        Extract the output vector for the last token in the sentence - similarly to how pooled_output
        is extracted for us when calling 'bert_model'.
        We need the question mask to find the last actual (non-padding) token
        :return:
        """
        if self._all_layers:
            encoded_layers = encoded_layers[-1]

        # A cool trick to extract the last "True" item in each row
        question_mask = question_mask.squeeze()
        # We already asserted this at batch_size == 1, but why not
        assert question_mask.dim() == 2
        # mask - roll(mask, -1) is 1 exactly at the last non-padding position.
        shifted_matrix = question_mask.roll(-1, 1)
        shifted_matrix[:, -1] = 0
        last_item_indices = question_mask - shifted_matrix

        # TODO: This row, for some reason, didn't work as expected, but it is much better then the implementation that follows
        # last_token_tensor = encoded_layers[last_item_indices]

        num_pairs, token_number, hidden_size = encoded_layers.size()
        assert last_item_indices.size() == (num_pairs, token_number)
        # Don't worry, expand doesn't allocate new memory, it simply views the tensor differently
        expanded_last_item_indices = last_item_indices.unsqueeze(2).expand(num_pairs, token_number, hidden_size)
        last_token_tensor = encoded_layers.masked_select(expanded_last_item_indices.byte())
        last_token_tensor = last_token_tensor.reshape(num_pairs, hidden_size)

        # Re-use BERT's own pooler (dense + activation) on the last token.
        pooled_output = self._bert_model.pooler.dense(last_token_tensor)
        pooled_output = self._bert_model.pooler.activation(pooled_output)

        return pooled_output

    def forward(self,
                question: Dict[str, torch.LongTensor],
                choice1_indexes: List[int] = None,
                choice2_indexes: List[int] = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Score every (choice1, choice2) pair encoded in ``question`` and,
        depending on configuration, either train the pairwise classifier
        directly (BCE over pairs containing the gold label) or feed all pair
        probabilities through the trained comparison layer (cross-entropy
        over choices).
        """

        self._debug -= 1
        input_ids = question['bert']

        # input_ids.size() == (batch_size, num_pairs, max_sentence_length)
        batch_size, num_pairs, _ = question['bert'].size()
        question_mask = (input_ids != 0).long()

        if self._train_comparison_layer:
            assert num_pairs == self._num_choices * (self._num_choices - 1)

        # Segment ids
        real_segment_ids = question['bert-type-ids'].clone()
        # Change the last 'SEP' to belong to the second answer (for symmetry)
        last_seps = (real_segment_ids.roll(-1) == 2) & (real_segment_ids == 1)
        real_segment_ids[last_seps] = 2
        # Update segment ids so that they are '1' for answers and '0' for the question
        real_segment_ids = (real_segment_ids == 0) | (real_segment_ids == 2)
        real_segment_ids = real_segment_ids.long()

        # TODO: How to extract last token pooled output if batch size != 1
        assert batch_size == 1

        # Run model
        encoded_layers, first_vectors_pooled_output = self._bert_model(input_ids=util.combine_initial_dims(input_ids),
                                            token_type_ids=util.combine_initial_dims(real_segment_ids),
                                            attention_mask=util.combine_initial_dims(question_mask),
                                            output_all_encoded_layers=self._all_layers)

        if self._use_comparative_bert:
            last_vectors_pooled_output = self._extract_last_token_pooled_output(encoded_layers, question_mask)
        else:
            last_vectors_pooled_output = None
        if self._all_layers:
            mixed_layer = self._scalar_mix(encoded_layers, question_mask)
            first_vectors_pooled_output = self._bert_model.pooler(mixed_layer)

        # Apply dropout
        first_vectors_pooled_output = self._dropout(first_vectors_pooled_output)
        if self._use_comparative_bert:
            last_vectors_pooled_output = self._dropout(last_vectors_pooled_output)

        # Classify
        if not self._use_comparative_bert:
            pair_label_logits = self._classifier(first_vectors_pooled_output)
        else:
            if self._use_bilinear_classifier:
                pair_label_logits = self._classifier(first_vectors_pooled_output, last_vectors_pooled_output)
            else:
                all_pooled_output = torch.cat((first_vectors_pooled_output, last_vectors_pooled_output), 1)
                pair_label_logits = self._classifier(all_pooled_output)

        pair_label_logits = pair_label_logits.view(-1, num_pairs)

        pair_label_probs = torch.sigmoid(pair_label_logits)

        output_dict = {}
        pair_label_probs_flat = pair_label_probs.squeeze(1)
        output_dict['pair_label_probs'] = pair_label_probs_flat.view(-1, num_pairs)
        output_dict['pair_label_logits'] = pair_label_logits
        output_dict['choice1_indexes'] = choice1_indexes
        output_dict['choice2_indexes'] = choice2_indexes

        if not self._train_comparison_layer:
            if label is not None:
                label = label.unsqueeze(1)
                label = label.expand(-1, num_pairs)
                # Only pairs that contain the gold choice carry a supervision
                # signal; the target is "is choice1 the gold one?".
                relevant_pairs = (choice1_indexes == label) | (choice2_indexes == label)
                relevant_probs = pair_label_probs[relevant_pairs]
                choice1_is_the_label = (choice1_indexes == label)[relevant_pairs]
                # choice1_is_the_label = choice1_is_the_label.type_as(relevant_logits)

                # NOTE(review): BCEWithLogitsLoss is applied to sigmoid
                # *probabilities* here, not raw logits — confirm intended.
                loss = self._loss(relevant_probs, choice1_is_the_label.float())
                self._accuracy(relevant_probs >= 0.5, choice1_is_the_label)
                output_dict["loss"] = loss

            return output_dict
        else:
            choice_logits = self._comparison_layer_2(self._comparison_layer_1_activation(self._comparison_layer_1(
                pair_label_probs)))
            output_dict['choice_logits'] = choice_logits
            output_dict['choice_probs'] = torch.softmax(choice_logits, 1)
            output_dict['predicted_choice'] = torch.argmax(choice_logits, 1)

            if label is not None:
                loss = self._loss(choice_logits, label)
                self._accuracy(choice_logits, label)
                output_dict["loss"] = loss

            return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the single tracked accuracy under the key 'EM'."""
        return {
            'EM': self._accuracy.get_metric(reset),
        }
class MultiGranularityHierarchicalAttentionFusionNetworks(Model):
    """
    Hierarchical attention-fusion reading-comprehension model for dialog QA
    (predicts answer span start/end plus a yes/no/followup label per question).

    Parameters
    ----------
    vocab : ``Vocabulary``
    elmo_embedder : ``TextFieldEmbedder``
        Produces 1024-d ELMo embeddings for question and passage.
    tokens_embedder : ``TextFieldEmbedder``
        GloVe + char-CNN embedder for question and passage tokens.
    features_embedder : ``TextFieldEmbedder``
        Embeds hand-crafted passage features (40-d total — see ``forward``).
    phrase_layer : ``Seq2SeqEncoder``
        Shared contextual encoder applied to question and passage.
    projected_layer : ``Seq2SeqEncoder``
        Encoder run over the fused, feature-augmented passage.
    contextual_passage / contextual_question : ``Seq2SeqEncoder``
        Final contextual encoders for passage and question.
    dropout : ``float``, optional (default=0.2)
        Variational dropout probability.
    """
    def __init__(
            self,
            vocab: Vocabulary,
            elmo_embedder: TextFieldEmbedder,
            tokens_embedder: TextFieldEmbedder,
            features_embedder: TextFieldEmbedder,
            phrase_layer: Seq2SeqEncoder,
            projected_layer: Seq2SeqEncoder,
            contextual_passage: Seq2SeqEncoder,
            contextual_question: Seq2SeqEncoder,
            dropout: float = 0.2,
            regularizer: Optional[RegularizerApplicator] = None,
            initializer: InitializerApplicator = InitializerApplicator(),
    ):
        super(MultiGranularityHierarchicalAttentionFusionNetworks,
              self).__init__(vocab, regularizer)
        self.elmo_embedder = elmo_embedder
        self.tokens_embedder = tokens_embedder
        self.features_embedder = features_embedder
        self._phrase_layer = phrase_layer
        self._encoding_dim = self._phrase_layer.get_output_dim()
        # +1024 because ELMo vectors are concatenated after the phrase layer.
        self.projected_layer = torch.nn.Linear(self._encoding_dim + 1024,
                                               self._encoding_dim)
        # Fusion layers for passage, question, and passage self-alignment.
        self.fuse_p = FusionLayer(self._encoding_dim)
        self.fuse_q = FusionLayer(self._encoding_dim)
        self.fuse_s = FusionLayer(self._encoding_dim)
        self.projected_lstm = projected_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        # Scores each question position for the self-align weighted sum.
        self.linear_self_align = torch.nn.Linear(self._encoding_dim, 1)
        # self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._self_attention = BilinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim)
        # Bilinear scorers for span start and span end.
        self.bilinear_layer_s = BilinearSeqAtt(self._encoding_dim, self._encoding_dim)
        self.bilinear_layer_e = BilinearSeqAtt(self._encoding_dim, self._encoding_dim)
        # 3-way head: yes / no / followup.
        self.yesno_predictor = FeedForward(self._encoding_dim, self._encoding_dim, 3)
        self.relu = torch.nn.ReLU()
        # Maximum answer-span length (tokens) during decoding.
        self._max_span_length = 30
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Run the full pipeline: embed, co-attend question/passage, self-align
        the passage, predict span logits and yes/no/followup logits, and (when
        gold spans are given) compute the training loss and metrics.

        Shape-comment conventions used below: ``cnt`` = total_qa_count,
        ``m`` = question length, ``n`` = passage length, ``h`` = encoding dim.
        """
        # question is a dialog: (batch, qa_pairs, question_len, chars).
        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        # Padded QA slots carry yesno < 0; mask them out of all metrics.
        qa_mask = torch.ge(yesno_list, 0).view(total_qa_count)
        # GloVe and simple cnn char embedding, embedding dim = 100 + 100 = 200
        word_emb_ques = self.tokens_embedder(
            question,
            num_wrapping_dims=1).reshape(total_qa_count, max_q_len,
                                         self.tokens_embedder.get_output_dim())
        word_emb_pass = self.tokens_embedder(passage)
        # Elmo embedding, embedding dim = 1024
        elmo_ques = self.elmo_embedder(question, num_wrapping_dims=1).reshape(
            total_qa_count, max_q_len, self.elmo_embedder.get_output_dim())
        elmo_pass = self.elmo_embedder(passage)
        # Passage features embedding, embedding dim = 20 + 20 = 40
        pass_feat = self.features_embedder(passage)
        # GloVe + cnn + Elmo
        embedded_question = self._variational_dropout(
            torch.cat([word_emb_ques, elmo_ques], dim=2))
        embedded_passage = self._variational_dropout(
            torch.cat([word_emb_pass, elmo_pass], dim=2))
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()
        # Repeat the single passage once per QA turn in the dialog.
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)
        # Concatenate Elmo after encoded passage
        encode_passage = self._phrase_layer(embedded_passage, passage_mask)
        projected_passage = self.relu(
            self.projected_layer(torch.cat([encode_passage, elmo_pass],
                                           dim=2)))
        # Concatenate Elmo after encoded question
        encode_question = self._phrase_layer(embedded_question, question_mask)
        projected_question = self.relu(
            self.projected_layer(torch.cat([encode_question, elmo_ques],
                                           dim=2)))
        encoded_passage = self._variational_dropout(projected_passage)
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(
            total_qa_count, passage_length, self._encoding_dim)
        repeated_pass_feat = (pass_feat.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)).view(total_qa_count, passage_length, 40)
        encoded_question = self._variational_dropout(projected_question)

        # total_qa_count * max_q_len * passage_length
        # cnt * m * n
        s = torch.bmm(encoded_question,
                      repeated_encoded_passage.transpose(2, 1))
        alpha = util.masked_softmax(s,
                                    question_mask.unsqueeze(2).expand(
                                        s.size()),
                                    dim=1)
        # cnt * n * h
        aligned_p = torch.bmm(alpha.transpose(2, 1), encoded_question)
        # cnt * m * n
        beta = util.masked_softmax(s,
                                   repeated_passage_mask.unsqueeze(1).expand(
                                       s.size()),
                                   dim=2)
        # cnt * m * h
        aligned_q = torch.bmm(beta, repeated_encoded_passage)
        fused_p = self.fuse_p(repeated_encoded_passage, aligned_p)
        fused_q = self.fuse_q(encoded_question, aligned_q)
        # add manual features here
        q_aware_p = self._variational_dropout(
            self.projected_lstm(
                torch.cat([fused_p, repeated_pass_feat], dim=2),
                repeated_passage_mask))

        # cnt * n * n
        # self_p = torch.bmm(q_aware_p, q_aware_p.transpose(2, 1))
        # self_p = self.bilinear_self_align(q_aware_p)
        self_p = self._self_attention(q_aware_p, q_aware_p)
        # Pairwise valid-token mask; the eye term removes self-to-self
        # attention on the diagonal.
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_p.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        lamb = util.masked_softmax(self_p, mask, dim=2)
        # lamb = util.masked_softmax(self_p, repeated_passage_mask, dim=2)
        # cnt * n * h
        self_aligned_p = torch.bmm(lamb, q_aware_p)
        # cnt * n * h
        fused_self_p = self.fuse_s(q_aware_p, self_aligned_p)
        contextual_p = self._variational_dropout(
            self.contextual_layer_p(fused_self_p, repeated_passage_mask))
        # contextual_p = self.contextual_layer_p(fused_self_p, repeated_passage_mask)
        contextual_q = self._variational_dropout(
            self.contextual_layer_q(fused_q, question_mask))
        # contextual_q = self.contextual_layer_q(fused_q, question_mask)
        # cnt * m
        gamma = util.masked_softmax(
            self.linear_self_align(contextual_q).squeeze(2),
            question_mask,
            dim=1)
        # cnt * h
        weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1)
        span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p)
        span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p)
        # cnt * n * 1 cnt * 1 * h
        span_yesno_logits = self.yesno_predictor(
            torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1)))
        # span_yesno_logits = self.yesno_predictor(contextual_p)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)
        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       self._max_span_length)
        output_dict: Dict[str, Any] = {}
        # Compute the loss for training
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            # span_yesno_logits is (cnt, n, 3); flat index of (i, pos, k) is
            # pos*3 + i*n*3 + k — the three appends below pick up k = 0, 1, 2.
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)
            # Yes/no loss is taken at the *gold* end position; accuracy at the
            # *predicted* end position.
            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # NOTE(review): metadata is accessed unconditionally here despite its
        # None default — callers apparently always supply it; confirm.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                yesno_pred = predicted_span[2]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_query_id_list.append(iid)
                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Map predicted yes/no indices to their vocabulary labels."""
        yesno_tags = [[
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in yn_list
        ] for yn_list in output_dict.pop("yesno")]
        output_dict['yesno'] = yesno_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Report span start/end/joint accuracy, yes/no accuracy and official F1."""
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'yesno': self._span_yesno_accuracy.get_metric(reset),
            'f1': self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        """
        Greedy constrained decoding: for each row find the (start, end) pair
        maximizing start_logit + end_logit with end >= start and
        end - start <= max_span_length, then attach the argmax yes/no label
        at the chosen end position.  Returns a ``(batch_size, 3)`` long tensor
        of (start, end, yesno).
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 3),
                                                     dtype=torch.long)
        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Track the best start seen so far at or before position j.
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
        return best_word_span
class BidafV4(Model):
    """
    MODIFICATION NOTE: This class is a modification of BiDAF. In here we try
    to see what happens to our results if we convert the question encoder
    into a simple term frequency (bag-of-words) encoder which disregards word
    order. By doing so we analyze whether BiDAF can learn to solve SQuAD
    without having to encode the question sequentially. It has been shown in
    previous work that BiDAF and other models trained on SQuAD do not focus
    on questions words as we would expect them to. For example, they will
    often focus on only a few of the question words.

    ORIGINAL DOCSTRING:
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word
    embeddings and a character-level encoder, pass the word representations
    through a bi-LSTM/GRU, use a matrix of attentions to put question
    information into the passage word representations (this is the only part
    that is at all non-standard), pass this through another few layers of
    bi-LSTMs/GRUs, and do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get
        as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input
        and passing it through the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in
        between embedding tokens and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded
        passage and question representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in
        between the bidirectional attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions
        into the passage state before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after
        all encoders (pytorch LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This
        gives a ~2x speedup, with only a slight performance decrease, if any.
        We haven't experimented much with this yet, but have confirmed that
        we still get very similar performance with much faster training times.
        We still use the mask for all softmaxes, but avoid the shuffling
        that's required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty
        during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidafV4, self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        # Span-start predictor sees the merged passage (4 * encoding_dim)
        # concatenated with the modeled passage.
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(),
                               4 * encoding_dim, "modeling layer input dim",
                               "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               phrase_layer.get_input_dim(),
                               "text field embedder output dim",
                               "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # No-op stand-in so the call sites don't need to branch.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage
            contains the answer to the question, and predicts the beginning
            and ending positions of the answer within the passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying
            to predict - the beginning position of the answer with the
            passage.  This is an `inclusive` token index.  If this is given,
            we will compute a loss that gets included in the output
            dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying
            to predict - the ending position of the answer with the passage.
            This is an `inclusive` token index.  If this is given, we will
            compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage
            text, and token offsets into the passage for each instance in the
            batch.  We use this for computing official metrics using the
            official SQuAD evaluation script.  The length of this list should
            be the batch size, and each dictionary should have the keys
            ``id``, ``original_passage``, and ``token_offsets``.  If you only
            want the best span string and don't care about official metrics,
            you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing
            unnormalized log probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing
            unnormalized log probabilities of the span end position
            (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits``
            and ``span_end_logits`` to find the most probable span.  Shape is
            ``(batch_size, 2)`` and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the
            batch, we also return the string from the original passage that
            the model thinks is the best answer to the question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        # encoded_question = self._dropout(embedded_question)
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
        # NOTE(review): softmax over dim=1 with the question mask — looks like
        # a deliberate experimental variant; verify the mask/dim pairing.
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)
        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)
        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        new_similarity_matrix = self._matrix_attention(encoded_passage,
                                                       q2c_vecs)
        masked_similarity = util.replace_masked_values(
            new_similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- Original variant
        # # We replace masked values with something really negative here, so they don't affect the
        # # max below.
        # masked_similarity = util.replace_masked_values(similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)
        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v9:
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors],
        #                                  dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation,
                                   passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Mask out padding before decoding so it can never be chosen.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Report span accuracies plus official SQuAD EM and F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Greedy decoding of the most probable (start, end) span per row:
        maximizes start_logit + end_logit subject to end >= start.
        Returns a ``(batch_size, 2)`` long tensor of inclusive token indices.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Track the best start position seen at or before j.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
class TweetJointly(Model):
    """Joint model for tweet sentiment-span extraction.

    Predicts the ``selected_text`` span of a tweet given its sentiment, with
    two optional auxiliary heads:

    * a sentiment-classification head (``sentiment_task``), and
    * a candidate-span re-ranking head (``candidate_span_task``) that scores
      the top-N spans proposed by the start/end distributions.

    The main head predicts start/end logits over the tokens of
    ``text_with_sentiment``; losses from enabled heads are summed (with the
    configured weights) into a single ``loss``.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model_name: str = "bert-base-uncased",
        feedforward: Optional[FeedForward] = None,
        smoothing: bool = False,
        smooth_alpha: float = 0.7,
        sentiment_task: bool = False,
        sentiment_task_weight: float = 1.0,
        sentiment_classification_with_label: bool = True,
        sentiment_seq2vec: Optional[Seq2VecEncoder] = None,
        candidate_span_task: bool = False,
        candidate_span_task_weight: float = 1.0,
        candidate_delay: int = 30000,
        candidate_span_num: int = 5,
        candidate_classification_layer_units: int = 128,
        candidate_span_extractor: Optional[SpanExtractor] = None,
        candidate_span_with_logits: bool = False,
        dropout: Optional[float] = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        # BERTweet checkpoints need the project's custom embedder wrapper.
        if "BERTweet" not in transformer_model_name:
            self._text_field_embedder = BasicTextFieldEmbedder({
                "tokens": PretrainedTransformerEmbedder(transformer_model_name)
            })
        else:
            self._text_field_embedder = BasicTextFieldEmbedder(
                {"tokens": TweetBertEmbedder(transformer_model_name)})
        # Span start & end head: a 2-unit output per token (start, end logits).
        if feedforward is None:
            self._linear_layer = nn.Sequential(
                nn.Linear(self._text_field_embedder.get_output_dim(), 128),
                nn.ReLU(),
                nn.Linear(128, 2),
            )
        else:
            self._linear_layer = feedforward
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._jaccard = Jaccard()
        # The candidate head is only trained after `_delay` examples have
        # been seen, so the start/end distributions are warmed up first.
        self._candidate_delay = candidate_delay
        self._delay = 0
        self._smoothing = smoothing
        self._smooth_alpha = smooth_alpha
        if smoothing:
            # Smoothed targets are distributions -> KL divergence loss.
            self._loss = nn.KLDivLoss(reduction="batchmean")
        else:
            self._loss = nn.CrossEntropyLoss()
        # Sentiment classification head.
        self._sentiment_task = sentiment_task
        if self._sentiment_task:
            self._sentiment_classification_accuracy = CategoricalAccuracy()
            self._sentiment_loss_log = LossLog()
            # Registered as a buffer so it moves with the model's device.
            self.register_buffer("sentiment_task_weight",
                                 torch.tensor(sentiment_task_weight))
            self._sentiment_classification_with_label = (
                sentiment_classification_with_label)
            if sentiment_seq2vec is None:
                raise ConfigurationError(
                    "sentiment task is True, we need a sentiment seq2vec encoder"
                )
            else:
                self._sentiment_encoder = sentiment_seq2vec
                self._sentiment_linear = nn.Linear(
                    self._sentiment_encoder.get_output_dim(),
                    vocab.get_vocab_size("labels"),
                )
        # Candidate-span re-ranking head.
        self._candidate_span_task = candidate_span_task
        if candidate_span_task:
            assert candidate_span_num > 0
            assert candidate_span_task_weight > 0
            assert candidate_classification_layer_units > 0
            self._candidate_span_num = candidate_span_num
            self.register_buffer("candidate_span_task_weight",
                                 torch.tensor(candidate_span_task_weight))
            self._candidate_classification_layer_units = (
                candidate_classification_layer_units)
            self._span_classification_accuracy = CategoricalAccuracy()
            self._candidate_loss_log = LossLog()
            self._candidate_span_linear = nn.Linear(
                self._text_field_embedder.get_output_dim(),
                self._candidate_classification_layer_units,
            )
            if candidate_span_extractor is None:
                self._candidate_span_extractor = EndpointSpanExtractor(
                    input_dim=self._candidate_classification_layer_units)
            else:
                self._candidate_span_extractor = candidate_span_extractor
            if candidate_span_with_logits:
                # +1 input feature: the summed start/end logits of the span.
                self._candidate_with_logits = True
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim() + 1, 1)
            else:
                self._candidate_with_logits = False
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim(), 1)
            self._candidate_jaccard = Jaccard()
        if sentiment_task or candidate_span_task:
            self._base_loss_log = LossLog()
        else:
            self._base_loss_log = None
        if dropout is not None:
            self._dropout = nn.Dropout(dropout)
        else:
            self._dropout = None

    def forward(  # type: ignore
        self,
        text: Dict[str, Dict[str, torch.LongTensor]],
        sentiment: torch.IntTensor,
        text_with_sentiment: Dict[str, Dict[str, torch.LongTensor]],
        text_span: torch.IntTensor,
        selected_text_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the joint forward pass.

        Returns an output dict with start/end logits and probs, the decoded
        ``best_span``, per-head predictions, and (when ``selected_text_span``
        is given) the combined ``loss``.

        NOTE(review): the trailing jaccard loop iterates ``metadata``
        unconditionally, so ``metadata`` is effectively required despite its
        ``None`` default — confirm against the dataset reader.
        """
        # Shape: (batch_size, text_length, hidden_dim)
        embedded_question = self._text_field_embedder(text_with_sentiment)
        if self._dropout is not None:
            embedded_question = self._dropout(embedded_question)
        self._delay += int(embedded_question.size(0))
        # Span start & end head.
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)
        # Only positions inside the original tweet (text_span) may be answers.
        possible_answer_mask = torch.zeros_like(
            util.get_token_ids_from_text_field_tensors(
                text_with_sentiment)).bool()
        for i, (start, end) in enumerate(text_span):
            possible_answer_mask[i, start:end + 1] = True
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)
        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }
        loss = torch.tensor(0.0).to(embedded_question.device)
        # Sentiment classification head.
        if self._sentiment_task:
            if self._sentiment_classification_with_label:
                global_context_vec = self._sentiment_encoder(embedded_question)
            else:
                embedded_only_text = self._text_field_embedder(text)
                if self._dropout is not None:
                    embedded_only_text = self._dropout(embedded_only_text)
                global_context_vec = self._sentiment_encoder(
                    embedded_only_text)
            sentiment_logits = self._sentiment_linear(global_context_vec)
            sentiment_probs = torch.softmax(sentiment_logits, dim=-1)
            self._sentiment_classification_accuracy(sentiment_probs,
                                                    sentiment)
            sentiment_loss = cross_entropy(sentiment_logits, sentiment)
            self._sentiment_loss_log(sentiment_loss)
            loss.add_(self.sentiment_task_weight * sentiment_loss)
            predict_sentiment_idx = sentiment_probs.argmax(dim=-1)
            sentiment_predicts = []
            for i in predict_sentiment_idx.tolist():
                sentiment_predicts.append(
                    self.vocab.get_token_from_index(i, "labels"))
            output_dict["sentiment_logits"] = sentiment_logits
            output_dict["sentiment_probs"] = sentiment_probs
            output_dict["sentiment_predicts"] = sentiment_predicts
        # Candidate-span re-ranking head (only after the warm-up delay).
        if self._candidate_span_task and (self._delay >=
                                          self._candidate_delay):
            # shape: (batch_size, passage_length, embedding_dim)
            text_features_for_candidate = self._candidate_span_linear(
                embedded_question)
            text_features_for_candidate = torch.relu(
                text_features_for_candidate)
            with torch.no_grad():
                # Top-N proposals from the start/end distributions; no grad so
                # the proposal step is treated as fixed.
                # shape: (batch_size, candidate_num, 2)
                candidate_span = get_candidate_span(span_start_probs,
                                                    span_end_probs,
                                                    self._candidate_span_num)
                candidate_span_list = candidate_span.tolist()
            output_dict["candidate_spans"] = candidate_span_list
            if selected_text_span is not None:
                candidate_span, candidate_span_label = self.candidate_span_with_labels(
                    candidate_span, selected_text_span)
            else:
                candidate_span_label = None
            # shape: (batch_size, candidate_num, span_extractor_output_dim)
            span_feature_vec = self._candidate_span_extractor(
                text_features_for_candidate, candidate_span)
            if self._candidate_with_logits:
                candidate_span_start_logits = torch.gather(
                    span_start_logits, 1, candidate_span[:, :, 0])
                candidate_span_end_logits = torch.gather(
                    span_end_logits, 1, candidate_span[:, :, 1])
                candidate_span_sum_logits = (candidate_span_start_logits +
                                             candidate_span_end_logits)
                span_feature_vec = torch.cat(
                    (span_feature_vec,
                     candidate_span_sum_logits.unsqueeze(2)), -1)
            # shape: (batch_size, candidate_num)
            # BUGFIX: squeeze only the trailing singleton dim. A bare
            # .squeeze() also removed the batch dim when batch_size == 1,
            # collapsing the logits to 1-D and corrupting the softmax/argmax
            # and cross_entropy below.
            span_classification_logits = self._candidate_span_vec_linear(
                span_feature_vec).squeeze(-1)
            span_classification_probs = torch.softmax(
                span_classification_logits, -1)
            output_dict[
                "span_classification_probs"] = span_classification_probs
            candidate_best_span_idx = span_classification_probs.argmax(dim=-1)
            # Flatten (batch, candidate) indexing into row offsets of the
            # (batch * candidate_num, 2) view.
            view_idx = (
                candidate_best_span_idx +
                torch.arange(0, end=candidate_best_span_idx.shape[0]).to(
                    candidate_best_span_idx.device) * self._candidate_span_num)
            candidate_span_view = candidate_span.view(-1, 2)
            candidate_best_spans = candidate_span_view.index_select(
                0, view_idx)
            output_dict["candidate_best_spans"] = candidate_best_spans.tolist()
            if selected_text_span is not None:
                self._span_classification_accuracy(span_classification_probs,
                                                   candidate_span_label)
                candidate_span_loss = cross_entropy(span_classification_logits,
                                                    candidate_span_label)
                self._candidate_loss_log(candidate_span_loss)
                weighted_loss = self.candidate_span_task_weight * candidate_span_loss
                if candidate_span_loss > 1e2:
                    # Debug trace for pathological losses.
                    print(f"candidate loss: {candidate_span_loss}")
                    print(
                        f"span_classification_logits: {span_classification_logits}"
                    )
                    print(f"candidate_span_label: {candidate_span_label}")
                loss.add_(weighted_loss)
            candidate_best_spans = candidate_best_spans.detach().cpu().numpy()
            output_dict["best_candidate_span_str"] = []
            for metadata_entry, best_span in zip(metadata,
                                                 candidate_best_spans):
                text_with_sentiment_tokens = metadata_entry[
                    "text_with_sentiment_tokens"]
                predicted_start, predicted_end = tuple(best_span)
                if predicted_end >= len(text_with_sentiment_tokens):
                    predicted_end = len(text_with_sentiment_tokens) - 1
                best_span_string = self.span_tokens_to_text(
                    metadata_entry["text"],
                    text_with_sentiment_tokens,
                    predicted_start,
                    predicted_end,
                )
                output_dict["best_candidate_span_str"].append(best_span_string)
                answers = metadata_entry.get("selected_text", "")
                if len(answers) > 0:
                    self._candidate_jaccard(best_span_string, answers)
        # Compute the loss for training.
        if selected_text_span is not None:
            span_start = selected_text_span[:, 0]
            span_end = selected_text_span[:, 1]
            # -1 marks instances without a gold span.
            span_mask = span_start != -1
            self._span_accuracy(
                best_spans,
                selected_text_span,
                span_mask.unsqueeze(-1).expand_as(best_spans),
            )
            if not self._smoothing:
                start_loss = cross_entropy(span_start_logits,
                                           span_start,
                                           ignore_index=-1)
                if torch.any(start_loss > 1e9):
                    logger.critical("Start loss too high (%r)", start_loss)
                    logger.critical("span_start_logits: %r",
                                    span_start_logits)
                    logger.critical("span_start: %r", span_start)
                    logger.critical("text_with_sentiment: %r",
                                    text_with_sentiment)
                    assert False
                end_loss = cross_entropy(span_end_logits,
                                         span_end,
                                         ignore_index=-1)
                if torch.any(end_loss > 1e9):
                    logger.critical("End loss too high (%r)", end_loss)
                    logger.critical("span_end_logits: %r", span_end_logits)
                    logger.critical("span_end: %r", span_end)
                    assert False
            else:
                # Label smoothing: target mass decays geometrically
                # (alpha^distance) away from the gold endpoint, restricted to
                # the answerable positions, then renormalized.
                sequence_length = span_start_logits.size(1)
                device = span_start.device
                start_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_start)
                start_smooth_probs = torch.exp(
                    start_distance *
                    torch.log(torch.tensor(self._smooth_alpha).to(device)))
                start_smooth_probs = start_smooth_probs * possible_answer_mask
                start_smooth_probs = start_smooth_probs / start_smooth_probs.sum(
                    -1, keepdim=True)
                span_start_log_probs = span_start_logits - torch.log(
                    torch.exp(span_start_logits).sum(-1)).unsqueeze(-1)
                end_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_end)
                end_smooth_probs = torch.exp(
                    end_distance *
                    torch.log(torch.tensor(self._smooth_alpha).to(device)))
                end_smooth_probs = end_smooth_probs * possible_answer_mask
                end_smooth_probs = end_smooth_probs / end_smooth_probs.sum(
                    -1, keepdim=True)
                span_end_log_probs = span_end_logits - torch.log(
                    torch.exp(span_end_logits).sum(-1)).unsqueeze(-1)
                start_loss = self._loss(span_start_log_probs,
                                        start_smooth_probs)
                end_loss = self._loss(span_end_log_probs, end_smooth_probs)
            span_start_end_loss = (start_loss + end_loss) / 2
            if self._base_loss_log is not None:
                self._base_loss_log(span_start_end_loss)
            loss.add_(span_start_end_loss)
            self._span_start_accuracy(span_start_logits, span_start,
                                      span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)
            output_dict["loss"] = loss
        # Compute the best-span jaccard against the gold selected_text.
        best_spans = best_spans.detach().cpu().numpy()
        output_dict["best_span_str"] = []
        for metadata_entry, best_span in zip(metadata, best_spans):
            text_with_sentiment_tokens = metadata_entry[
                "text_with_sentiment_tokens"]
            predicted_start, predicted_end = tuple(best_span)
            best_span_string = self.span_tokens_to_text(
                metadata_entry["text"],
                text_with_sentiment_tokens,
                predicted_start,
                predicted_end,
            )
            output_dict["best_span_str"].append(best_span_string)
            answers = metadata_entry.get("selected_text", "")
            if len(answers) > 0:
                self._jaccard(best_span_string, answers)
        return output_dict

    @staticmethod
    def candidate_span_with_labels(
        candidate_span: torch.Tensor, selected_text_span: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Label each instance with the candidate of highest jaccard overlap
        with the gold span; candidates are returned unchanged."""
        candidate_span_label = batch_span_jaccard(
            candidate_span, selected_text_span).max(-1).indices
        return candidate_span, candidate_span_label

    @staticmethod
    def get_candidate_span_mask(candidate_span: torch.Tensor,
                                passage_length: int) -> torch.Tensor:
        """Build a (batch, candidate_num, passage_length) 0/1 mask with ones
        on each candidate's inclusive [start, end] token range."""
        device = candidate_span.device
        batch_size, candidate_num = candidate_span.size()[:-1]
        candidate_span_mask = torch.zeros(batch_size, candidate_num,
                                          passage_length).to(device)
        for i in range(batch_size):
            for j in range(candidate_num):
                span_start, span_end = candidate_span[i][j]
                candidate_span_mask[i][j][span_start:span_end + 1] = 1
        return candidate_span_mask

    @staticmethod
    def span_tokens_to_text(source_text, tokens, span_start, span_end):
        """Map a token-level span back to a substring of ``source_text``.

        Wordpiece tokens without a character offset (``idx is None``) are
        skipped: the start index walks left, the end index walks inward,
        until a token with a real offset is found.
        """
        text_with_sentiment_tokens = tokens
        predicted_start = span_start
        predicted_end = span_end
        while (predicted_start >= 0 and
               text_with_sentiment_tokens[predicted_start].idx is None):
            predicted_start -= 1
        if predicted_start < 0:
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_start].text}' at index "
                f"'{span_start}' to an offset in the original text.")
            character_start = 0
        else:
            character_start = text_with_sentiment_tokens[predicted_start].idx
        # NOTE(review): this decrements predicted_end (walks the end inward)
        # whereas the analogous AllenNLP code increments it; confirm the
        # direction is intentional — the loop guard compares against
        # len(tokens) and so only the initial value can trip it.
        while (predicted_end < len(text_with_sentiment_tokens) and
               text_with_sentiment_tokens[predicted_end].idx is None):
            predicted_end -= 1
        if predicted_end >= len(text_with_sentiment_tokens):
            print(text_with_sentiment_tokens)
            print(len(text_with_sentiment_tokens))
            print(span_end)
            print(predicted_end)
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_end].text}' at index "
                f"'{span_end}' to an offset in the original text.")
            character_end = len(source_text)
        else:
            end_token = text_with_sentiment_tokens[predicted_end]
            if end_token.idx == 0:
                # NOTE(review): the +1 when idx == 0 looks like a
                # tokenizer-offset workaround — verify against the reader.
                character_end = (end_token.idx +
                                 len(sanitize_wordpiece(end_token.text)) + 1)
            else:
                character_end = end_token.idx + len(
                    sanitize_wordpiece(end_token.text))
        best_span_string = source_text[character_start:character_end].strip()
        return best_span_string

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span metrics, plus per-head metrics for enabled tasks."""
        jaccard = self._jaccard.get_metric(reset)
        metrics = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "jaccard": jaccard,
        }
        if self._candidate_span_task:
            metrics[
                "candidate_span_acc"] = self._span_classification_accuracy.get_metric(
                    reset)
            metrics["candidate_jaccard"] = self._candidate_jaccard.get_metric(
                reset)
            metrics["candidate_loss"] = self._candidate_loss_log.get_metric(
                reset)
        if self._sentiment_task:
            metrics[
                "sentiment_acc"] = self._sentiment_classification_accuracy.get_metric(
                    reset)
            metrics["sentiment_loss"] = self._sentiment_loss_log.get_metric(
                reset)
        if self._base_loss_log is not None:
            metrics["base_loss"] = self._base_loss_log.get_metric(reset)
        return metrics
class EdgeProbingTask(Task):
    """ Generic class for fine-grained edge probing.

    Acts as a classifier, but with multiple targets for each input text.

    Targets are of the form (span1, span2, label), where span1 and span2 are
    half-open token intervals [i, j).

    Subclass this for each dataset, or use register_task with appropriate kw
    args.
    """

    @property
    def _tokenizer_suffix(self):
        """ Suffix to make sure we use the correct source files,
        based on the given tokenizer.
        """
        if self.tokenizer_name:
            return ".retokenized." + self.tokenizer_name
        else:
            return ""

    def tokenizer_is_supported(self, tokenizer_name):
        """ Check if the tokenizer is supported for this task. """
        # Assume all tokenizers supported; if retokenized data not found
        # for this particular task, we'll just crash on file loading.
        return True

    def __init__(
        self,
        path: str,
        max_seq_len: int,
        name: str,
        label_file: str = None,
        files_by_split: Dict[str, str] = None,
        single_sided: bool = False,
        **kw,
    ):
        """Construct an edge probing task.

        path, max_seq_len, and name are passed by the code in preprocess.py;
        remaining arguments should be provided by a subclass constructor or via
        @register_task.

        Args:
            path: data directory
            max_seq_len: maximum sequence length (currently ignored)
            name: task name
            label_file: relative path to labels file
            files_by_split: split name ('train', 'val', 'test') mapped to
                relative filenames (e.g. 'train': 'train.json')
            single_sided: if true, only use span1.
        """
        super().__init__(name, **kw)
        assert label_file is not None
        assert files_by_split is not None
        self._files_by_split = {
            split: os.path.join(path, fname) + self._tokenizer_suffix
            for split, fname in files_by_split.items()
        }
        self.path = path
        self.label_file = os.path.join(self.path, label_file)
        self.max_seq_len = max_seq_len
        self.single_sided = single_sided

        # Placeholders; see self.load_data()
        self._iters_by_split = None
        self.all_labels = None
        self.n_classes = None

        # see add_task_label_namespace in preprocess.py
        self._label_namespace = self.name + "_labels"

        # Scorers
        self.mcc_scorer = FastMatthews()
        self.acc_scorer = BooleanAccuracy()  # binary accuracy
        self.f1_scorer = F1Measure(positive_label=1)  # binary F1 overall
        self.val_metric = f"{self.name}_f1"  # TODO: switch to MCC?
        self.val_metric_decreases = False

    def get_all_labels(self) -> List[str]:
        return self.all_labels

    @classmethod
    def _stream_records(cls, filename):
        """Yield JSON records from filename, skipping those with no targets;
        logs read/skip counts once the stream is exhausted."""
        skip_ctr = 0
        total_ctr = 0
        for record in utils.load_json_data(filename):
            total_ctr += 1
            # Skip records with empty targets.
            # TODO(ian): don't do this if generating negatives!
            if not record.get("targets", None):
                skip_ctr += 1
                continue
            yield record
        log.info(
            "Read=%d, Skip=%d, Total=%d from %s",
            total_ctr - skip_ctr,
            skip_ctr,
            total_ctr,
            filename,
        )

    @staticmethod
    def merge_preds(record: Dict, preds: Dict) -> Dict:
        """ Merge predictions into record, in-place.

        List-valued predictions should align to targets, and are attached to
        the corresponding target entry.

        Non-list predictions are attached to the top-level record.
        """
        record["preds"] = {}
        for target in record["targets"]:
            target["preds"] = {}
        for key, val in preds.items():
            if isinstance(val, list):
                assert len(val) == len(record["targets"])
                for i, target in enumerate(record["targets"]):
                    target["preds"][key] = val[i]
            else:
                # non-list predictions, attach to top-level preds
                record["preds"][key] = val
        return record

    def load_data(self):
        """Eagerly load labels and all splits into memory."""
        self.all_labels = list(utils.load_lines(self.label_file))
        self.n_classes = len(self.all_labels)
        iters_by_split = collections.OrderedDict()
        for split, filename in self._files_by_split.items():
            # # Lazy-load using RepeatableIterator.
            # loader = functools.partial(utils.load_json_data,
            #                            filename=filename)
            # iter = serialize.RepeatableIterator(loader)
            # Renamed from `iter` to avoid shadowing the builtin.
            records = list(self._stream_records(filename))
            iters_by_split[split] = records
        self._iters_by_split = iters_by_split

    def get_split_text(self, split: str):
        """ Get split text as iterable of records.

        Split should be one of 'train', 'val', or 'test'.
        """
        return self._iters_by_split[split]

    @classmethod
    def get_num_examples(cls, split_text):
        """ Return number of examples in the result of get_split_text.

        Subclass can override this if data is not stored in column format.
        """
        return len(split_text)

    @classmethod
    def _make_span_field(cls, s, text_field, offset=1):
        # Convert a half-open [i, j) target span to an inclusive SpanField,
        # shifted by `offset` to account for the prepended boundary token.
        return SpanField(s[0] + offset, s[1] - 1 + offset, text_field)

    def make_instance(self, record, idx, indexers,
                      model_preprocessing_interface) -> Type[Instance]:
        """Convert a single record to an AllenNLP Instance."""
        tokens = record["text"].split()  # already space-tokenized by Moses
        tokens = model_preprocessing_interface.boundary_token_fn(
            tokens)  # apply model-appropriate variants of [cls] and [sep].
        text_field = sentence_to_text_field(tokens, indexers)

        d = {}
        d["idx"] = MetadataField(idx)
        d["input1"] = text_field
        d["span1s"] = ListField([
            self._make_span_field(t["span1"], text_field, 1)
            for t in record["targets"]
        ])
        if not self.single_sided:
            d["span2s"] = ListField([
                self._make_span_field(t["span2"], text_field, 1)
                for t in record["targets"]
            ])

        # Always use multilabel targets, so be sure each label is a list.
        labels = [
            utils.wrap_singleton_string(t["label"]) for t in record["targets"]
        ]
        d["labels"] = ListField([
            MultiLabelField(label_set,
                            label_namespace=self._label_namespace,
                            skip_indexing=False) for label_set in labels
        ])
        return Instance(d)

    def process_split(
            self, records, indexers,
            model_preprocessing_interface) -> Iterable[Type[Instance]]:
        """ Process split text into a list of AllenNLP Instances. """

        def _map_fn(r, idx):
            return self.make_instance(r, idx, indexers,
                                      model_preprocessing_interface)

        return map(_map_fn, records, itertools.count())

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """ Yield sentences, used to compute vocabulary. """
        # Renamed loop variable from `iter` to avoid shadowing the builtin.
        for split, records in self._iters_by_split.items():
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            for record in records:
                yield record["text"].split()

    def get_metrics(self, reset=False):
        """Get metrics specific to the task"""
        metrics = {}
        metrics["mcc"] = self.mcc_scorer.get_metric(reset)
        metrics["acc"] = self.acc_scorer.get_metric(reset)
        precision, recall, f1 = self.f1_scorer.get_metric(reset)
        metrics["precision"] = precision
        metrics["recall"] = recall
        metrics["f1"] = f1
        return metrics
class QaNetSemantic(Model):
    """
    This class implements Adams Wei Yu's `QANet Model <https://openreview.net/forum?id=B14TlG-RW>`_
    for machine reading comprehension published at ICLR 2018.

    The overall architecture of QANet is very similar to BiDAF. The main
    difference is that QANet replaces the RNN encoder with CNN + self-attention.
    There are also some minor differences in the modeling layer and output
    layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as
        input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and
        passing it through the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in
        between embedding tokens and doing the passage-question attention.
    matrix_attention_layer : ``MatrixAttention``
        The matrix attention function that we will use when comparing encoded
        passage and question representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in
        between the bidirectional attention and predicting span start and end.
    dropout_prob : ``float``, optional (default=0.1)
        If greater than 0, we will apply dropout with this probability between
        layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty
        during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 matrix_attention_layer: MatrixAttention,
                 modeling_layer: Seq2SeqEncoder,
                 dropout_prob: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)
        # Layer dimensions are derived from the configured sub-modules.
        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()
        self._text_field_embedder = text_field_embedder
        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim,
                                                     encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)
        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim,
                                                    encoding_in_dim)
        self._phrase_layer = phrase_layer
        self._matrix_attention = matrix_attention_layer
        # 4x: [passage, p-q vectors, and the two elementwise products]
        # concatenated in forward().
        self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4,
                                                    modeling_in_dim)
        self._modeling_layer = modeling_layer
        self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()
        # Identity function stands in for dropout when disabled.
        self._dropout = torch.nn.Dropout(
            p=dropout_prob) if dropout_prob > 0 else lambda x: x

        # evaluation

        # BLEU (NOTE: accumulators are created but never fed in this class;
        # reporting is commented out in get_metrics).
        self._bleu_score_types_to_use = ["BLEU1", "BLEU2", "BLEU3", "BLEU4"]
        self._bleu_scores = {
            x: Average()
            for x in self._bleu_score_types_to_use
        }

        # ROUGE using pyrouge
        self._rouge_score_types_to_use = ['rouge-n', 'rouge-l', 'rouge-w']

        # if we have rouge-n as metric we actually get n scores like
        # rouge-1, rouge-2, .., rouge-n
        max_rouge_n = 4
        rouge_n_metrics = []
        if "rouge-n" in self._rouge_score_types_to_use:
            rouge_n_metrics = [
                "rouge-{0}".format(x) for x in range(1, max_rouge_n + 1)
            ]
        rouge_scores_names = rouge_n_metrics + [
            y for y in self._rouge_score_types_to_use if y != 'rouge-n'
        ]
        self._rouge_scores = {x: Average() for x in rouge_scores_names}
        self._rouge_evaluator = rouge.Rouge(
            metrics=self._rouge_score_types_to_use,
            max_n=max_rouge_n,
            limit_length=True,
            length_limit=100,
            length_limit_type='words',
            apply_avg=False,
            apply_best=False,
            alpha=0.5,  # Default F1_score
            weight_factor=1.2,
            stemming=True)

        initializer(self)

    def forward(  # type: ignore
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            passage_sem_views_q: torch.IntTensor = None,
            passage_sem_views_k: torch.IntTensor = None,
            question_sem_views_q: torch.IntTensor = None,
            question_sem_views_k: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains
            the answer to the question, and predicts the beginning and ending
            positions of the answer within the passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to
            predict - the beginning position of the answer with the passage.
            This is an `inclusive` token index. If this is given, we will
            compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to
            predict - the ending position of the answer with the passage. This
            is an `inclusive` token index. If this is given, we will compute a
            loss that gets included in the output dictionary.
        passage_sem_views_q : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Query (Q)
        passage_sem_views_k : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Key (K)
        question_sem_views_q : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Query (Q)
        question_sem_views_k : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Key (K)
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage
            tokens, original passage text, and token offsets into the passage
            for each instance in the batch. The length of this list should be
            the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and
            ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing
            unnormalized log probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing
            unnormalized log probabilities of the span end position
            (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits``
            and ``span_end_logits`` to find the most probable span. Shape is
            ``(batch_size, 2)`` and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the
            batch, we also return the string from the original passage that
            the model thinks is the best answer to the question.
        """
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()

        # Embed, project to the encoder input size, then highway-transform.
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        # The semantic-view encoder takes extra Q/K view features; the plain
        # encoder only takes the mask.
        if isinstance(self._phrase_layer, QaNetSemanticEncoder):
            passage_sem_views_q = passage_sem_views_q.long()
            passage_sem_views_k = passage_sem_views_k.long()
            question_sem_views_q = question_sem_views_q.long()
            question_sem_views_k = question_sem_views_k.long()
            encoded_passage = self._dropout(
                self._phrase_layer(projected_embedded_passage,
                                   passage_sem_views_q, passage_sem_views_k,
                                   passage_mask))
            encoded_question = self._dropout(
                self._phrase_layer(projected_embedded_question,
                                   question_sem_views_q, question_sem_views_k,
                                   question_mask))
        else:
            encoded_passage = self._dropout(
                self._phrase_layer(projected_embedded_passage, passage_mask))
            encoded_question = self._dropout(
                self._phrase_layer(projected_embedded_question,
                                   question_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention,
                                             question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage,
                                                    attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat([
                encoded_passage, passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * passage_passage_vectors
            ],
                      dim=-1))

        # Three stacked applications of the (shared) modeling layer; the
        # start predictor uses outputs 1+2, the end predictor outputs 1+3.
        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]

        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2))
        span_start_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        # Mask out non-passage positions before softmax/decoding.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            # Best-effort: a failure in loss/metric computation is logged and
            # the forward pass still returns predictions.
            try:
                loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, passage_mask),
                    span_start.squeeze(-1))
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(-1))
                loss += nll_loss(
                    util.masked_log_softmax(span_end_logits, passage_mask),
                    span_end.squeeze(-1))
                self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
                self._span_accuracy(best_span,
                                    torch.stack([span_start, span_end], -1))
                output_dict["loss"] = loss
            except Exception as e:
                logging.exception(e)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            all_reference_answers_text = []
            all_best_spans = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())

                # Answer string is rebuilt from whitespace-joined passage
                # tokens rather than character offsets.
                # offsets = metadata[i]['token_offsets']
                # start_offset = offsets[predicted_span[0]][0]
                # end_offset = offsets[predicted_span[1]][1]
                start_span = predicted_span[0]
                end_span = predicted_span[1]
                best_span_tokens = metadata[i]['passage_tokens'][
                    start_span:end_span + 1]
                best_span_string = " ".join(best_span_tokens)

                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
                    all_best_spans.append(best_span_string)
                    all_reference_answers_text.append(answer_texts)

            # ROUGE is only accumulated at evaluation time.
            if not self.training:
                self.calculate_rouge(all_best_spans,
                                     all_reference_answers_text)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def calculate_rouge(self, predictions, references):
        """Accumulate per-item ROUGE F-scores into the Average metrics.

        ``predictions`` and ``references`` are parallel lists (one predicted
        answer string and its reference answer list per instance).
        """
        # calculate rouge
        references_text = references
        predictions_text = predictions
        metrics_with_per_item_scores = self._rouge_evaluator.get_scores(
            predictions_text, references_text)

        for metric, results in sorted(metrics_with_per_item_scores.items(),
                                      key=lambda x: x[0]):
            for hypothesis_id, results_per_ref in enumerate(results):
                # we report the max f-score of the two answers
                curr_item_rouge_f = max(results_per_ref['f'])
                self._rouge_scores[metric](curr_item_rouge_f)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies, SQuAD EM/F1, and accumulated ROUGE scores."""
        exact_match, f1_score = self._metrics.get_metric(reset)
        metrics = {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

        # # report bleu scores
        # for k,v in self._bleu_scores.items():
        #     metrics[k] = v.get_metric(reset)

        for k, v in self._rouge_scores.items():
            metrics[k] = v.get_metric(reset)

        return metrics
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word
    embeddings and a character-level encoder, pass the word representations
    through a bi-LSTM/GRU, use a matrix of attentions to put question
    information into the passage word representations (this is the only part
    that is at all non-standard), pass this through another few layers of
    bi-LSTMs/GRUs, and do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as
        input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and
        passing it through the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in
        between embedding tokens and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage
        and question representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in
        between the bidirectional attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into
        the passage state before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after
        all encoders (pytorch LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This
        gives a ~2x speedup, with only a slight performance decrease, if any.
        We haven't experimented much with this yet, but have confirmed that we
        still get very similar performance with much faster training times.  We
        still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty
        during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        # Span-start prediction sees [merged attention (4 * enc) ; modeled passage].
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these
        # aren't necessarily obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(),
                               4 * encoding_dim, "modeling layer input dim",
                               "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               phrase_layer.get_input_dim(),
                               "text field embedder output dim",
                               "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so forward() can call self._dropout unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(  # type: ignore
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains
            the answer to the question, and predicts the beginning and ending
            positions of the answer within the passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to
            predict - the beginning position of the answer with the passage.
            This is an `inclusive` token index.  If this is given, we will
            compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to
            predict - the ending position of the answer with the passage.  This
            is an `inclusive` token index.  If this is given, we will compute a
            loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage
            tokens, original passage text, and token offsets into the passage
            for each instance in the batch.  The length of this list should be
            the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and
            ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing
            unnormalized log probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing
            unnormalized log probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits``
            and ``span_end_logits`` to find the most probable span.  Shape is
            ``(batch_size, 2)`` and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch,
            we also return the string from the original passage that the model
            thinks is the best answer to the question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't
        # affect the max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation,
                                   passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Mask out padded positions before span decoding.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        # NOTE(review): this calls the *module-level* get_best_span, not the
        # static method defined below — presumably they are equivalent; confirm.
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['best_span_indices'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Convert token indices to character offsets into the passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                output_dict['best_span_indices'].append(
                    [start_offset, end_offset])
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies and SQuAD EM/F1 accumulated so far."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """Return the highest-scoring (start <= end) span for each batch item,
        shape ``(batch_size, 2)``."""
        # We call the inputs "logits" - they could either be unnormalized
        # logits or normalized log probabilities.  A log_softmax operation is a
        # constant shifting of the entire logit vector, so taking an argmax
        # over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower
        # triangle has entries where the span ends before it starts.
        # (log of the 0/1 triu mask gives 0 for valid cells, -inf for invalid.)
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask
        # Here we take the span matrix and flatten it, then find the best span
        # using argmax.  We can recover the start and end indices from this
        # flattened list using simple modular arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
class EvidenceExtraction(Model):
    """
    Evidence-extraction model in the style of R-NET: a gated attention-based
    recurrent layer builds a question-aware passage representation, a gated
    self-matching layer refines it, and a pointer network predicts the answer
    span.

    Parameters
    ----------
    vocab : ``Vocabulary``
    embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields``.
    question_encoder : ``Seq2SeqEncoder``
        Encoder for the question; its output dim (``2H`` for a bi-RNN) fixes
        the model's working width.
    passage_encoder : ``Seq2SeqEncoder``
        Encoder for the passage; assumed to have the same output dim as
        ``question_encoder``.
    r : ``float``, optional (default=0.8)
        Mixing weight for the (currently disabled) passage-ranking loss.
    dropout : ``float``, optional (default=0.1)
        Dropout applied after every recurrent computation.
    initializer : ``InitializerApplicator``, optional
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional
        If provided, will be used to calculate the regularization penalty
        during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 passage_encoder: Seq2SeqEncoder,
                 r: float = 0.8,
                 dropout: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EvidenceExtraction, self).__init__(vocab, regularizer)
        self._embedder = embedder
        self._question_encoder = question_encoder
        self._passage_encoder = passage_encoder

        # size: 2H (output width of the bidirectional encoders)
        encoding_dim = question_encoder.get_output_dim()

        # Gated attention-based recurrent network (question-aware passage v_P).
        self._gru_cell = nn.GRUCell(2 * encoding_dim, encoding_dim)
        self._gate = nn.Linear(2 * encoding_dim, 2 * encoding_dim)

        # Passage-ranking head.  Only referenced by the disabled ranking branch
        # in forward(); kept so existing checkpoints still load.
        self._match_layer_1 = nn.Linear(2 * encoding_dim, encoding_dim)
        self._match_layer_2 = nn.Linear(encoding_dim, 1)

        self._question_attention_for_passage = Attention(
            NonlinearSimilarity(encoding_dim))
        self._question_attention_for_question = Attention(
            NonlinearSimilarity(encoding_dim))
        # normalize=False: the pointer network needs raw logits, not a softmax.
        self._passage_attention_for_answer = Attention(
            NonlinearSimilarity(encoding_dim), normalize=False)
        self._passage_attention_for_ranking = Attention(
            NonlinearSimilarity(encoding_dim))

        # Gated self-matching layer.
        self._passage_self_attention = Attention(
            NonlinearSimilarity(encoding_dim))
        self._self_gru_cell = nn.GRUCell(2 * encoding_dim, encoding_dim)
        # BUG FIX: the gate must emit one weight per feature of its input
        # [v_P_t ; attended_passage], which has size 2 * encoding_dim.  The
        # original Linear(2 * encoding_dim, encoding_dim) produced a tensor
        # that could not be multiplied element-wise with that input.
        self._self_gate = nn.Linear(2 * encoding_dim, 2 * encoding_dim)

        # Pointer network for answer-span prediction.
        self._answer_net = nn.GRUCell(encoding_dim, encoding_dim)
        # Learned query vector used to pool the question into r_Q.
        self._v_r_Q = nn.Parameter(torch.rand(encoding_dim))
        self._r = r

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so forward() can call self._dropout unconditionally.
            self._dropout = lambda x: x
        initializer(self)

    def forward(self,
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata=None) -> Dict[str, torch.Tensor]:
        """
        Run the model on one batch.

        Returns a dict with ``span_start_idx`` / ``span_end_idx`` (argmax
        predictions), plus ``loss`` when gold ``span_start``/``span_end`` are
        given and ``best_span_str`` / tokens when ``metadata`` is given.
        """
        # shape: B x T x E
        embedded_question = self._embedder(question)
        embedded_passage = self._embedder(passage)
        batch_size = embedded_question.size(0)
        total_passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # shape: B x T x 2H
        encoded_question = self._dropout(
            self._question_encoder(embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._passage_encoder(embedded_passage, passage_mask))
        passage_mask = passage_mask.float()
        question_mask = question_mask.float()
        encoding_dim = encoded_question.size(-1)

        # Initial recurrent state.  new_zeros keeps device and dtype in sync
        # with the encoder output (replaces explicit .cuda() branching).
        # shape: B x 2H
        gru_hidden = encoded_passage.new_zeros(batch_size, encoding_dim)

        # ---- Gated attention-based RNN: question-aware passage v_P ------------
        question_awared_passage = []
        for timestep in range(total_passage_length):
            # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
            attn_weights = self._question_attention_for_passage(
                encoded_passage[:, timestep, :], encoded_question,
                question_mask)
            # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
            attended_question = util.weighted_sum(encoded_question,
                                                  attn_weights)
            # shape: B x 4H
            passage_question_combined = torch.cat(
                [encoded_passage[:, timestep, :], attended_question], dim=-1)
            # Input gate, then one recurrent step.
            gate = torch.sigmoid(self._gate(passage_question_combined))
            gru_input = gate * passage_question_combined
            # shape: B x 2H
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            question_awared_passage.append(gru_hidden)
        # shape: B x T x 2H -- question-aware passage representation v_P
        question_awared_passage = torch.stack(question_awared_passage, dim=1)

        # ---- Gated self-matching layer ----------------------------------------
        # NOTE(review): the recurrent state carries over from the loop above
        # rather than being re-initialised; preserved as in the original —
        # confirm this is intended.
        self_attended_passage = []
        for timestep in range(total_passage_length):
            attn_weights = self._passage_self_attention(
                question_awared_passage[:, timestep, :],
                question_awared_passage, passage_mask)
            attended_passage = util.weighted_sum(question_awared_passage,
                                                 attn_weights)
            # shape: B x 4H
            input_combined = torch.cat(
                [question_awared_passage[:, timestep, :], attended_passage],
                dim=-1)
            gate = torch.sigmoid(self._self_gate(input_combined))
            gru_input = gate * input_combined
            # BUG FIX: use the dedicated self-matching cell; the original
            # reused self._gru_cell and left self._self_gru_cell unused.
            gru_hidden = self._dropout(
                self._self_gru_cell(gru_input, gru_hidden))
            self_attended_passage.append(gru_hidden)
        self_attended_passage = torch.stack(self_attended_passage, dim=1)

        # ---- Question summary r_Q ---------------------------------------------
        # shape: B x T = attention(B x 2H, B x T x 2H)
        v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size,
                                                      encoding_dim)
        attn_weights = self._question_attention_for_question(
            v_r_Q_tiled, encoded_question, question_mask)
        # shape: B x 2H
        r_Q = util.weighted_sum(encoded_question, attn_weights)

        # ---- Pointer network ---------------------------------------------------
        # shape: B x T = attention(B x 2H, B x T x 2H); raw logits.
        span_start_logits = self._passage_attention_for_answer(
            r_Q, self_attended_passage, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits,
                                               passage_mask)
        span_start_log_probs = util.masked_log_softmax(span_start_logits,
                                                       passage_mask)
        # shape: B x 2H -- context vector conditioned on the start distribution
        c_t = util.weighted_sum(self_attended_passage, span_start_probs)
        # shape: B x 2H -- one pointer-network step
        h_1 = self._dropout(self._answer_net(c_t, r_Q))

        span_end_logits = self._passage_attention_for_answer(
            h_1, self_attended_passage, passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits,
                                                     passage_mask)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        # NOTE(review): the passage-ranking branch (using _match_layer_1/2 and
        # _passage_attention_for_ranking, weighted by self._r) was disabled in
        # the original; its commented-out body has been removed.  Recover it
        # from version control if it is ever revived.

        output_dict = {}
        if span_start is not None:
            # Answer-pointer loss: NLL of the gold start and end positions.
            AP_loss = F.nll_loss(span_start_log_probs,
                                 span_start.squeeze(-1)) + \
                F.nll_loss(span_end_log_probs, span_end.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict['loss'] = AP_loss

        _, max_start = torch.max(span_start_probs, dim=1)
        _, max_end = torch.max(span_end_probs, dim=1)
        output_dict['span_start_idx'] = max_start
        output_dict['span_end_idx'] = max_end

        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Token indices -> character offsets into the original passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies and SQuAD EM/F1 accumulated so far."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """Greedy search for the best (start <= end) span per batch item.

        Keeps a running argmax of the start logits while scanning end
        positions, so each end index is paired with the best start at or
        before it.  Returns a long tensor of shape ``(batch_size, 2)``.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        # Long tensor on the same device as the logits (replaces the old
        # Variable/resize_/fill_ construction).
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)
        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()
        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]
                val2 = span_end_logits[b, j]
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'EvidenceExtraction':
        """Construct the model from a ``Params`` configuration block."""
        embedder_params = params.pop("text_field_embedder")
        embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        question_encoder = Seq2SeqEncoder.from_params(
            params.pop("question_encoder"))
        passage_encoder = Seq2SeqEncoder.from_params(
            params.pop("passage_encoder"))
        dropout = params.pop_float('dropout', 0.1)
        r = params.pop_float('r', 0.8)
        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))
        return cls(vocab=vocab,
                   embedder=embedder,
                   question_encoder=question_encoder,
                   passage_encoder=passage_encoder,
                   r=r,
                   dropout=dropout,
                   initializer=initializer,
                   regularizer=regularizer)
class DialogQA(Model):
    """
    This class implements modified version of BiDAF
    (with self attention and residual layer, from Clark and Gardner ACL 17 paper) model as used in
    Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog, list of question answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length: ``int``, optional (default=0)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        # Passage-to-question attention (trilinear similarity, BiDAF-style).
        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        # Projects the 4-way BiDAF concatenation back down to encoding_dim.
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            # Turn-number embedding for each question in the dialog.
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            # 4 marker labels (<k_start>, <k_in>, <k_end>, ...) per previous answer, plus 'O'.
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        # NOTE(review): the two "gt" metrics below are created but never updated or
        # reported anywhere in this class — presumably leftovers; confirm before removing.
        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the
            previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>,
            <1_end>.  For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # Dialogs are flattened so that each (dialog, turn) pair becomes one row:
        # total_qa_count = batch_size * max_qa_count.
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        # Padded turns carry a negative followup label; qa_mask selects real turns.
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        # The single passage per dialog is repeated once per turn.
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb],
                                                      dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            # No per-turn markers: encode the passage once and repeat the encoding.
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        # Self-attention block with a residual connection (Clark & Gardner '17).
        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
            * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        # Zero the diagonal so a token cannot attend to itself.
        self_mask = torch.eye(passage_length, passage_length,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        # The end predictor also sees the start representation.
        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        # yes/no and followup are predicted per passage position from the end representation.
        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask),
                            span_start.view(-1), ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            # Flattened indexing trick: span_yesno_logits viewed as 1-D has 3 class
            # logits per (row, position), so row i / position p starts at
            # (i * passage_length + p) * 3.  Gather the 3 logits at the *gold* end.
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            # Same trick at the *predicted* end, used only for the accuracy metrics.
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        # NOTE(review): this block reads metadata unconditionally — presumably
        # metadata is always provided for this model; confirm against the dataset reader.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                # Map token-level span back to character offsets in the raw passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Convert predicted yesno/followup label indices to their vocabulary tokens."""
        yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] \
                      for yn_list in output_dict.pop("yesno")]
        followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list] \
                         for followup_list in output_dict.pop("followup")]
        output_dict['yesno'] = yesno_tags
        output_dict['followup'] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies, yesno/followup accuracies and the official QuAC F1."""
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'yesno': self._span_yesno_accuracy.get_metric(reset),
                'followup': self._span_followup_accuracy.get_metric(reset),
                'f1': self._official_f1.get_metric(reset),
                }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as
        # yesno prediction bit and followup prediction bit from the predicted span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        # Output columns: [start, end, yesno, followup].
        best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long)

        # Scan on CPU/numpy to avoid per-element tensor overhead.
        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Best start position at or before j.
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        # Pick yesno/followup classes from the logits at the chosen end position.
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
class BidirectionalAttentionFlow(Model):
    # NOTE(review): a second class with this exact name is defined later in this
    # module; at import time the later definition shadows this one — confirm which
    # is registered/used and consider renaming one of them.
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            char_field_embedder: TextFieldEmbedder,
            # num_highway_layers: int,
            phrase_layer: Seq2SeqEncoder,
            char_rnn: Seq2SeqEncoder,
            hops: int,
            hidden_dim: int,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._char_field_embedder = char_field_embedder
        # Binary-feature embedding (exact-match style flags), 2 values -> dim 5.
        self._features_embedder = nn.Embedding(2, 5)
        # self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim() + 5 * 3,
        #                                               num_highway_layers))
        self._phrase_layer = phrase_layer
        self._encoding_dim = phrase_layer.get_output_dim()
        # self._stacked_brnn = PytorchSeq2SeqWrapper(
        #     StackedBidirectionalLstm(input_size=self._encoding_dim, hidden_size=hidden_dim,
        #                              num_layers=3, recurrent_dropout_probability=0.2))
        self._char_rnn = char_rnn

        self.hops = hops

        # One aligner/SFU/aggregator stack per hop (Mnemonic-Reader-style iterative alignment).
        self.interactive_aligners = nn.ModuleList()
        self.interactive_SFUs = nn.ModuleList()
        self.self_aligners = nn.ModuleList()
        self.self_SFUs = nn.ModuleList()
        self.aggregate_rnns = nn.ModuleList()
        for i in range(hops):
            # interactive aligner
            self.interactive_aligners.append(
                layers.SeqAttnMatch(self._encoding_dim))
            self.interactive_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # self aligner
            self.self_aligners.append(layers.SelfAttnMatch(self._encoding_dim))
            self.self_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # aggregating
            # NOTE(review): with num_layers=1, PyTorch ignores the dropout=0.2
            # argument (it only applies between stacked layers) and warns at
            # construction — confirm whether inter-layer dropout was intended.
            self.aggregate_rnns.append(
                PytorchSeq2SeqWrapper(
                    nn.LSTM(input_size=self._encoding_dim,
                            hidden_size=hidden_dim,
                            num_layers=1,
                            dropout=0.2,
                            bidirectional=True,
                            batch_first=True)))

        # Memmory-based Answer Pointer
        self.mem_ans_ptr = layers.MemoryAnsPointer(x_size=self._encoding_dim,
                                                   y_size=self._encoding_dim,
                                                   hidden_size=hidden_dim,
                                                   hop=hops,
                                                   dropout_rate=0.2,
                                                   normalize=True)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno: torch.IntTensor = None,
            question_tf: torch.FloatTensor = None,
            passage_tf: torch.FloatTensor = None,
            q_em_cased: torch.IntTensor = None,
            p_em_cased: torch.IntTensor = None,
            q_em_uncased: torch.IntTensor = None,
            p_em_uncased: torch.IntTensor = None,
            q_in_lemma: torch.IntTensor = None,
            p_in_lemma: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Run the multi-hop reader: embed words + chars + hand features, align
        passage against question over ``self.hops`` hops, then point at the
        answer span and classify yes/no.  Returns span logits, the best span,
        and (when gold spans are given) a ``loss`` entry.
        """
        x1_c_emb = self._dropout(self._char_field_embedder(passage))
        x2_c_emb = self._dropout(self._char_field_embedder(question))
        # embedded_question = torch.cat([self._dropout(self._text_field_embedder(question)),
        #                                self._features_embedder(q_em_cased),
        #                                self._features_embedder(q_em_uncased),
        #                                self._features_embedder(q_in_lemma),
        #                                question_tf.unsqueeze(2)], dim=2)
        # embedded_passage = torch.cat([self._dropout(self._text_field_embedder(passage)),
        #                               self._features_embedder(p_em_cased),
        #                               self._features_embedder(p_em_uncased),
        #                               self._features_embedder(p_in_lemma),
        #                               passage_tf.unsqueeze(2)], dim=2)
        token_emb_q = self._dropout(self._text_field_embedder(question))
        token_emb_c = self._dropout(self._text_field_embedder(passage))
        # The embedder output packs word embedding (300) and NER/POS features (40)
        # along dim 2 — assumes those exact widths; TODO confirm against config.
        token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40], dim=2)
        token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40], dim=2)
        question_word_features = torch.cat([
            q_ner_and_pos,
            self._features_embedder(q_em_cased),
            self._features_embedder(q_em_uncased),
            self._features_embedder(q_in_lemma),
            question_tf.unsqueeze(2)
        ], dim=2)
        passage_word_features = torch.cat([
            p_ner_and_pos,
            self._features_embedder(p_em_cased),
            self._features_embedder(p_em_uncased),
            self._features_embedder(p_in_lemma),
            passage_tf.unsqueeze(2)
        ], dim=2)
        # embedded_question = self._highway_layer(embedded_q)
        # embedded_passage = self._highway_layer(embedded_q)

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # Run the char RNN over (batch*seq, chars, dim) and keep the last step
        # as a per-token character feature.
        char_features_c = self._char_rnn(
            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1),
                              x1_c_emb.size(2), x1_c_emb.size(3))),
            passage_lstm_mask.unsqueeze(2).repeat(
                1, 1, x1_c_emb.size(2)).reshape(
                    (x1_c_emb.size(0) * x1_c_emb.size(1),
                     x1_c_emb.size(2)))).reshape(
                         (x1_c_emb.size(0), x1_c_emb.size(1),
                          x1_c_emb.size(2), -1))[:, :, -1, :]
        char_features_q = self._char_rnn(
            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1),
                              x2_c_emb.size(2), x2_c_emb.size(3))),
            question_lstm_mask.unsqueeze(2).repeat(
                1, 1, x2_c_emb.size(2)).reshape(
                    (x2_c_emb.size(0) * x2_c_emb.size(1),
                     x2_c_emb.size(2)))).reshape(
                         (x2_c_emb.size(0), x2_c_emb.size(1),
                          x2_c_emb.size(2), -1))[:, :, -1, :]

        # token_emb_q, char_emb_q, question_word_features = torch.split(embedded_question, [300, 300, 56], dim=2)
        # token_emb_c, char_emb_c, passage_word_features = torch.split(embedded_passage, [300, 300, 56], dim=2)
        # char_features_q = self._char_rnn(char_emb_q, question_lstm_mask)
        # char_features_c = self._char_rnn(char_emb_c, passage_lstm_mask)

        emb_question = torch.cat(
            [token_emb_question, char_features_q, question_word_features],
            dim=2)
        emb_passage = torch.cat(
            [token_emb_passage, char_features_c, passage_word_features], dim=2)

        encoded_question = self._dropout(
            self._phrase_layer(emb_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(emb_passage, passage_lstm_mask))
        batch_size = encoded_question.size(0)
        passage_length = encoded_passage.size(1)
        encoding_dim = encoded_question.size(-1)

        # c_check = self._stacked_brnn(encoded_passage, passage_lstm_mask)
        # q = self._stacked_brnn(encoded_question, question_lstm_mask)
        c_check = encoded_passage
        q = encoded_question
        # Iterative alignment: interactive (passage-question) then self alignment,
        # each fused via SFU and aggregated with a BiLSTM, repeated per hop.
        for i in range(self.hops):
            q_tilde = self.interactive_aligners[i].forward(
                c_check, q, question_mask)
            c_bar = self.interactive_SFUs[i].forward(
                c_check,
                torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2))
            c_tilde = self.self_aligners[i].forward(c_bar, passage_mask)
            c_hat = self.self_SFUs[i].forward(
                c_bar,
                torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde], 2))
            c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask)

        # Predict
        start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward(
            c_check, q, passage_mask, question_mask)

        best_span, yesno_predict, loc = self.get_best_span(
            start_scores, end_scores, yesno_scores)

        output_dict = {
            "span_start_logits": start_scores,
            "span_end_logits": end_scores,
            "best_span": best_span
        }

        # Compute the loss for training.
        if span_start is not None:
            # NOTE(review): nll_loss on raw start_scores assumes MemoryAnsPointer
            # (normalize=True above) already emits log-probabilities — confirm.
            loss = nll_loss(start_scores, span_start.squeeze(-1))
            self._span_start_accuracy(start_scores, span_start.squeeze(-1))
            loss += nll_loss(end_scores, span_end.squeeze(-1))
            self._span_end_accuracy(end_scores, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            # Gather the 3 yesno logits at the gold end position via flat indexing:
            # row i, position p lives at offset (i * passage_length + p) in the
            # (batch*passage, 3) view.
            gold_span_end_loc = []
            span_end = span_end.view(batch_size).squeeze().data.cpu().numpy()
            for i in range(batch_size):
                gold_span_end_loc.append(
                    max(span_end[i] + i * passage_length, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1)

            # Same gather at the predicted end (loc already flat-indexed), metrics only.
            pred_span_end_loc = []
            for i in range(batch_size):
                pred_span_end_loc.append(max(loc[i], 0))
            predicted_end = span_start.new(pred_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(0, predicted_end).view(
                -1, 3)
            self._span_yesno_accuracy(_yesno, yesno.squeeze(-1))

            output_dict['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Map token indices back to character offsets in the raw passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        output_dict['yesno'] = yesno_predict
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies, yesno accuracy, and SQuAD EM/F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Convert predicted yesno label indices to their vocabulary tokens."""
        yesno_tags = [
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in output_dict.pop("yesno")
        ]
        output_dict['yesno'] = yesno_tags
        return output_dict

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor,
                      yesno_scores: torch.Tensor):
        """
        Exhaustive best-span search (start <= end) over summed logits.

        Returns ``(best_word_span, yesno_predict, loc)`` where ``best_word_span``
        is (batch, 2) start/end indices, ``yesno_predict`` is the argmax yesno
        class at the chosen end, and ``loc`` is the flat (row * passage_length +
        end) index used by the caller's ``index_select``.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)
        yesno_predict = span_start_logits.new_zeros(batch_size,
                                                    dtype=torch.long)
        loc = yesno_scores.new_zeros(batch_size, dtype=torch.long)

        # Scan on CPU/numpy to avoid per-element tensor overhead.
        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()
        yesno_logits = yesno_scores.detach().cpu().numpy()
        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Best start position at or before j.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]
                val2 = span_end_logits[b, j]
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
                    yesno_predict[b] = int(np.argmax(yesno_logits[b, j]))
                    loc[b] = j + passage_length * b
        return best_word_span, yesno_predict, loc
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity: lets the forward pass call self._dropout unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        # Optionally skip passing the mask to the LSTMs (see class docstring).
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Mask out padding positions before decoding the best span.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Map (inclusive) token indices to character offsets in the raw passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span-start/end/exact accuracies and official SQuAD EM/F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Constrained decoding of the most probable answer span.

        For each batch element, finds the (start, end) pair with start <= end that
        maximizes ``span_start_logits[start] + span_end_logits[end]`` via a single
        left-to-right pass that tracks the best start seen so far.  Returns a
        ``(batch_size, 2)`` long tensor of token indices.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        # Decoding is done on CPU/numpy; gradients are not needed here.
        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
class BidafV2(Model):
    """
    The modified version of official bidaf with support for squad v2.

    Differences from ``BidirectionalAttentionFlow``: when ``no_answer`` is True,
    a learned scalar threshold is appended as an extra "no answer" slot to both
    the span-start and span-end logits; if that slot wins, the predicted span is
    ``(-1, -1)`` and the answer string is empty.  The EM/F1 ``metric`` is
    injected so a SQuAD-v2-aware metric can be used.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 metric: Metric,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 no_answer: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidafV2, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = metric
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity: lets the forward pass call self._dropout unconditionally.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms
        # Learned "no answer" score, broadcast to each batch element in forward().
        self._threshold = torch.nn.Parameter(torch.zeros(1, 1))
        self._no_answer = no_answer

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        # NOTE: `last_dim_softmax` is the deprecated alias of `masked_softmax`; use the
        # current name, matching BidirectionalAttentionFlow above.
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Add no answer padding: the learned threshold becomes an extra, always-unmasked
        # logit slot at the end of both logit tensors.
        if self._no_answer:
            # Shape: (batch_size, passage_length + 1)
            passage_eval_mask = torch.cat([passage_mask, passage_mask.new_ones((batch_size, 1))],
                                          dim=-1)
            # Shape: (batch_size, 1)
            threshold = self._threshold.expand(batch_size, 1)
            # Shape: (batch_size, passage_length + 1)
            span_start_logits = torch.cat([span_start_logits, threshold], dim=-1)
            span_end_logits = torch.cat([span_end_logits, threshold], dim=-1)
        else:
            passage_eval_mask = passage_mask

        span_start_logits = util.replace_masked_values(span_start_logits, passage_eval_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_eval_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits, self._no_answer)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None and span_end is not None:
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            # In case there is no answer, convert span_start and span_end from -1 to
            # passage_length (the index of the appended "no answer" slot).
            if self._no_answer:
                # Copy before mutating so the caller's label tensors are untouched.
                # (clone() is the supported way to copy an existing tensor;
                # torch.tensor(tensor) emits a UserWarning.)
                span_start = span_start.clone()
                span_end = span_end.clone()
                for i in range(batch_size):
                    if span_start[i][0] == -1:
                        span_start[i][0] = passage_length
                        span_end[i][0] = passage_length
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_eval_mask),
                            span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_eval_mask),
                             span_end.squeeze(-1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                if predicted_span[0] < 0:
                    # (-1, -1) means the model predicted "no answer".
                    best_span_string = ''
                else:
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                # Unlike v1, the metric is updated even with empty answer_texts so
                # no-answer questions are scored.
                self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Output metrics include em, f1, no_em, span_acc
        """
        ret: Dict[str, Any] = {}
        if self._no_answer:
            # The injected metric is assumed to return a dict in v2 mode.
            ret = self._squad_metrics.get_metric(reset)
        else:
            exact_match, f1_score = self._squad_metrics.get_metric(reset)
            ret['em'] = exact_match
            ret['f1'] = f1_score
        ret['span_acc'] = self._span_accuracy.get_metric(reset)
        return ret

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor,
                      no_answer: bool = False) -> torch.Tensor:
        """
        Output best span (st, ed) where span_start_logits[st] + span_end_logits[ed] (st<=ed)
        is maximized.  If no_answer is set to True, span_start_logits[-1] + span_end_logits[-1]
        will be checked separately; if this value is max, return (-1, -1).
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        if no_answer:
            # The last logit slot is the "no answer" threshold, not a real token.
            passage_length = passage_length - 1
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        # Decoding is done on CPU/numpy; gradients are not needed here.
        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
            if no_answer and max_span_log_prob[b] < span_start_logits[b, -1] + span_end_logits[b, -1]:
                best_word_span[b, 0] = -1
                best_word_span[b, 1] = -1
                max_span_log_prob[b] = span_start_logits[b, -1] + span_end_logits[b, -1]
        return best_word_span