def _compute_span_representations(
        self,
        text_embeddings: torch.FloatTensor,
        text_mask: torch.FloatTensor,
        span_starts: torch.IntTensor,
        span_ends: torch.IntTensor) -> torch.FloatTensor:
    """
    Build one embedded vector per candidate span.

    Each span vector is the concatenation, along the last dimension, of:

    1. the contextual encoding at the span start,
    2. the contextual encoding at the span end,
    3. an embedding of the span width offset (``end - start``, so a one-token
       span has offset 0), and
    4. an attention-weighted summary, built by
       ``self._create_attended_span_representations``, over the *original*
       (non-contextual) ``text_embeddings``. Its head scores come from the
       contextual encodings.

    Parameters
    ----------
    text_embeddings : ``torch.FloatTensor``, required.
        The embedded document, shape (batch_size, document_length, embedding_dim).
    text_mask : ``torch.FloatTensor``, required.
        Shape (batch_size, document_length); marks the non-padding entries of
        ``text_embeddings``. It is handed to the context layer.
    span_starts : ``torch.IntTensor``, required.
        Start index of each candidate span. The ``squeeze(-1)`` applied below
        suggests a trailing singleton axis, i.e. (batch_size, num_spans, 1).
        NOTE(review): confirm against the caller.
    span_ends : ``torch.IntTensor``, required.
        End index of each candidate span; same shape as ``span_starts``.

    Returns
    -------
    span_embeddings : ``torch.FloatTensor``
        Shape (batch_size, num_spans,
        2 * context_layer.get_output_dim() + feature_size + embedding_dim),
        where feature_size is the output size of ``self._span_width_embedding``.
    """
    # Contextualise the whole document once.
    # Shape: (batch_size, document_length, encoding_dim)
    encoded_document = self._context_layer(text_embeddings, text_mask)

    # Encodings gathered at the start and at the end of every span, in that
    # order. Each has shape (batch_size, num_spans, encoding_dim).
    endpoint_vectors = [
        util.batched_index_select(encoded_document, boundary.squeeze(-1))
        for boundary in (span_starts, span_ends)
    ]

    # Width offsets (strictly speaking width - 1). The elementwise
    # difference keeps the shape of the index tensors.
    width_offsets = span_ends - span_starts
    # Shape: (batch_size, num_spans, feature_size)
    width_vectors = self._span_width_embedding(width_offsets.squeeze(-1))

    # Per-token head scores from the contextual encodings.
    # Shape: (batch_size, document_length, 1)
    token_head_scores = self._head_scorer(encoded_document)
    # The attended summary is taken over the original text embeddings, not
    # the contextual ones. Shape: (batch_size, num_spans, embedding_dim)
    head_summary = self._create_attended_span_representations(
        token_head_scores, text_embeddings, span_ends, width_offsets)

    return torch.cat(endpoint_vectors + [width_vectors, head_summary], -1)
def forward(self,
            text: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Multi-label classification of a single text.

    The text is embedded, encoded, pooled and passed through a feedforward
    classifier. Each label gets an independent score, which is turned into a
    probability with a sigmoid.

    Parameters
    ----------
    text : Dict[str, torch.LongTensor]
        From a ``TextField``.
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField`` or ``MultiLabelField``, a tensor of shape
        ``(batch_size, num_labels)``. It is assumed to hold 0/1 targets. When
        given, the loss is computed and the F1 metrics are updated.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Not used by this method.

    Returns
    -------
    An output dictionary consisting of:

    logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` holding the
        unnormalised per-label scores.
    class_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)``, ``sigmoid(logits)``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; only present when ``label`` is given.
    """
    embedded_text = self.text_field_embedder(text)
    mask = util.get_text_field_mask(text)
    encoded_text = self.text_encoder(embedded_text, mask)
    pooled = self.pool(encoded_text, mask)
    hidden = self.classifier_feedforward(pooled)  # batch size x hidden size
    logits = self.prediction_layer(hidden)  # batch size x num labels

    # Reference: https://pytorch.org/docs/master/nn.html#sigmoid
    probabilities = torch.sigmoid(logits)  # batch size x num labels

    output_dict = {"logits": logits, "class_probs": probabilities}

    if label is not None:
        # logit > 0 is the same decision as sigmoid(logit) > 0.5.
        predictions = (logits.detach() > 0.0).long()
        label_data = label.squeeze(-1).detach().long()
        self.micro_f1(predictions, label_data)
        # Squeeze only the trailing axis on both sides so logits and targets
        # keep matching shapes. A bare ``squeeze()`` also removed the batch
        # axis when batch_size == 1, which handed the loss an input of shape
        # (num_labels,) against a target of shape (1, num_labels).
        output_dict["loss"] = self.loss(logits.squeeze(-1),
                                        label.squeeze(-1).float())
        # One F1 metric per label column, in the dict's iteration order.
        for i, label_metric in enumerate(self.label_f1.values()):
            label_metric(predictions[:, i], label[:, i].long())

    return output_dict
def forward(self,
            Orgquestion: Dict[str, torch.LongTensor],
            Relquestion: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Score how well ``Relquestion`` matches ``Orgquestion``.

    Both questions go through the same pipeline: embedder, sequence dropout,
    padding mask and encoder. ``self._matching_layer`` (a Neural Tensor
    Network) then extracts the interaction between the two encodings.
    ``self._pool_layer`` (k-max pooling) reduces it to a fixed size, and
    ``self._output_feedforward`` followed by a sigmoid gives the score.

    The value stored under ``"label_logits"`` is the post-sigmoid score, not
    a raw logit. It is also what ``self._loss`` and every metric receive.
    ``metadata`` is not used.
    """
    def encode(question):
        # Embed with sequence dropout, build the padding mask, then encode.
        embedded = self._seq_dropout(self._text_field_embedder(question))
        question_mask = util.get_text_field_mask(question)
        return self._encoder(embedded, question_mask)

    org_encoded = encode(Orgquestion)
    rel_encoded = encode(Relquestion)

    # Use the Neural Tensor Network to extract the interaction.
    interaction = self._matching_layer(org_encoded, rel_encoded)
    # k-max pooling down to a fixed dimension.
    pooled_features = self._pool_layer(interaction)
    scores = torch.sigmoid(self._output_feedforward(pooled_features))

    result = {"label_logits": scores}
    if label is None:
        return result

    result["loss"] = self._loss(scores, label)
    for metric in self.metrics.values():
        metric(scores.squeeze(1), label.squeeze(1))
    return result
def forward(self,
            features: torch.Tensor,
            metadata: List[Dict[str, Any]] = None,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    """
    Classify an instance from its precomputed overlap features.

    Parameters
    ----------
    features : torch.Tensor
        From a ``FloatField`` over the overlap features computed by the
        SimpleOverlapReader.
    metadata : List[Dict[str, Any]], optional
        Metadata information; not used by this method.
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``. When given, the loss is computed and accuracy
        is updated.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        unnormalised log probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        probabilities of the entailment label (softmax over the last axis).
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; only present when ``label`` is given.
    """
    label_logits = self.linear_mlp(features)
    # Normalise explicitly over the label axis. Calling softmax without
    # ``dim`` relies on PyTorch's deprecated implicit-dimension choice and
    # emits a UserWarning. For 2-D logits that implicit choice was already
    # the last axis, so the probabilities are unchanged.
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        # The loss gets a flat (batch_size,) target. Accuracy gets the label
        # with its last axis removed if that axis has size 1.
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
def forward(self,
            Orgquestion: Dict[str, torch.LongTensor],
            Relquestion: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Score how well ``Relquestion`` matches ``Orgquestion`` from the raw word
    embeddings, without a sequence encoder.

    The two questions are embedded with sequence dropout. ``self._matching_layer``
    computes their dot-product similarity matrix, which gets a channel axis
    at position 1. The result is processed by ``self._inference_encoder``
    (stacked convolution layers) and reduced to a fixed size by
    ``self._pool_layer`` (dynamic pooling). ``self._output_feedforward`` and
    a sigmoid then give the score.

    The value stored under ``"label_logits"`` is the post-sigmoid score; it
    is what ``self._loss`` and the metrics receive. ``metadata`` is unused.
    """
    # Embed both questions, premise first.
    premise_vectors, hypo_vectors = [
        self._seq_dropout(self._text_field_embedder(question))
        for question in (Orgquestion, Relquestion)
    ]

    # Dot-product similarity of the word embeddings, with a channel axis
    # inserted at dim 1 for the convolutional encoder.
    interaction = self._matching_layer(premise_vectors, hypo_vectors).unsqueeze(1)
    # Stacked conv layers over the interaction, then dynamic pooling to a
    # fixed dimension.
    pooled_features = self._pool_layer(self._inference_encoder(interaction))
    scores = torch.sigmoid(self._output_feedforward(pooled_features))

    result = {"label_logits": scores}
    if label is None:
        return result

    result["loss"] = self._loss(scores, label)
    for metric in self.metrics.values():
        metric(scores.squeeze(1), label.squeeze(1))
    return result
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            choices_list: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score each answer choice by attending from the question to the choices.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``
    choices_list : Dict[str, torch.LongTensor]
        From a ``List[TextField]``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``; index of the correct choice. When given, the
        loss is computed and accuracy is updated.
    metadata : List[Dict[str, Any]], optional
        Not used by this method.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of each choice being the correct answer.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of each
        choice being the correct answer.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, num_choices, hidden_size)
    encoded_choices_aggregated = embed_encode_and_aggregate_list_text_field(
        choices_list, self._text_field_embedder, self._embeddings_dropout,
        self._choice_encoder, self._choice_aggregate)

    # Shape: (batch_size, hidden_size)
    encoded_question_aggregated = embed_encode_and_aggregate_text_field(
        question, self._text_field_embedder, self._embeddings_dropout,
        self._question_encoder, self._question_aggregate)

    # The question is fed as a one-row matrix, so the attention is presumably
    # (batch_size, 1, num_choices). Squeeze only that row axis. A bare
    # ``squeeze()`` also dropped the batch axis when batch_size == 1 (or the
    # choice axis with a single choice), which broke the loss and accuracy
    # below.
    q_to_choices_att = self._matrix_attention_question_to_choice(
        encoded_question_aggregated.unsqueeze(1),
        encoded_choices_aggregated).squeeze(1)

    label_logits = q_to_choices_att
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {
        "label_logits": label_logits,
        "label_probs": label_probs
    }

    if label is not None:
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            hypothesis0: Dict[str, torch.LongTensor],
            hypothesis1: Dict[str, torch.LongTensor],
            hypothesis2: Dict[str, torch.LongTensor],
            hypothesis3: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score four candidate hypotheses and normalise the scores across them.

    Each hypothesis is embedded (with embedding dropout), encoded, max-pooled
    over the time axis and passed through ``self.output_prediction``. The
    four results are concatenated on the last axis and softmax-normalised.

    Parameters
    ----------
    hypothesis0, hypothesis1, hypothesis2, hypothesis3 : Dict[str, torch.LongTensor]
        The four candidates, each from a ``TextField``.
    label : torch.IntTensor, optional (default = None)
        Presumably the index (0-3) of the correct hypothesis. When given, the
        loss is computed and accuracy is updated.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        The concatenated per-hypothesis scores. This is ``(batch_size, 4)``
        when ``self.output_prediction`` emits one score per hypothesis.
    label_probs : torch.FloatTensor
        Softmax of ``label_logits`` over the last axis, viewed as
        ``(batch_size, 4)``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; only present when ``label`` is given.
    """
    logits = []
    for tokens in [hypothesis0, hypothesis1, hypothesis2, hypothesis3]:
        if isinstance(self.text_field_embedder, ElmoTokenEmbedder):
            # Reset ELMo's LSTM states before each hypothesis, presumably so
            # that state does not carry over from the previous one.
            self.text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states()
        embedded_text_input = self.embedding_dropout(
            self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        batch_size = embedded_text_input.size(0)
        encoded_text = self.encoder(embedded_text_input, mask)
        # Max-pool over the time axis (dim 1). Padded positions are not
        # excluded from this max.
        logits.append(self.output_prediction(encoded_text.max(1)[0]))

    logits = torch.cat(logits, -1)
    # The view also checks that there is exactly one score per hypothesis.
    class_probabilities = F.softmax(logits, dim=-1).view([batch_size, 4])
    output_dict = {
        "label_logits": logits,
        "label_probs": class_probabilities
    }

    if label is not None:
        loss = self._loss(logits, label.long().view(-1))
        self._accuracy(logits, label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            qa_pairs: Dict[str, torch.LongTensor],
            answer_index: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,  # pylint:disable=unused-argument
            ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Rank answer candidates with a BERT next-sentence-prediction head.

    Parameters
    ----------
    qa_pairs : Dict[str, torch.LongTensor]
        From a ``TextField`` (that has a bert-pretrained token indexer). It
        carries one wrapping dimension, one entry per candidate.
    answer_index : torch.IntTensor, optional (default = None)
        From a ``LabelField``; the index of the correct candidate. When
        given, the loss is computed and accuracy is updated.
    metadata : List[Dict[str, Any]], optional
        Not used by this method.

    Returns
    -------
    An output dictionary consisting of:

    nsp_logits : torch.FloatTensor
        The raw output of ``self.bert``: two NSP logits per candidate, e.g.
        ``(batch_size, 5, 2)``.
    nsp_probs : torch.FloatTensor
        ``softmax(nsp_logits)`` over the last axis.
    label_probs : torch.FloatTensor
        A ``(batch_size, num_candidates)`` distribution over candidates: the
        softmax of each candidate's column-0 NSP probability.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; only present when ``answer_index`` is
        given.
    """
    # Shape: (batch, num_candidates, 2)
    nsp_logits = self.bert(qa_pairs, num_wrapping_dims=1)
    nsp_probs = nsp_logits.softmax(dim=-1)
    # Each candidate is scored by column 0 of its NSP distribution.
    # Shape: (batch, num_candidates)
    candidate_scores = nsp_probs[..., 0]

    output = {
        "nsp_logits": nsp_logits,
        "nsp_probs": nsp_probs,
        "label_probs": candidate_scores.softmax(dim=-1),
    }
    if answer_index is None:
        return output

    gold = answer_index.squeeze(-1)
    # The loss and accuracy receive the column-0 probabilities as scores.
    output["loss"] = self._loss(candidate_scores, gold)
    self._accuracy(candidate_scores, gold)
    return output
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Classify an evidence candidate from the BERT encoding of its first token.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        From a ``TextField``.
    label : ``torch.IntTensor``, optional
        From a ``LabelField``; the class we are trying to predict. When
        given, a loss is computed and included in the output dictionary, and
        accuracy is updated.

    Returns
    -------
    An output dictionary consisting of:

    logits : torch.FloatTensor
        The output of ``self._classifier``, documented as ``(batch_size, 2)``
        unnormalised scores of the evidence selection confidence.
    probs : torch.FloatTensor
        ``softmax(logits)`` over the last axis.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; only present when ``label`` is given.
    """
    # ``self._bert`` yields per-token hidden states (batch, seq_len, hidden).
    # Keep position 0 for every example (presumably the [CLS] token).
    first_token_states = self._bert(tokens)[:, 0, :]
    # Shape: (batch, hidden)
    features = self._pooler(first_token_states)
    # Dropout is optional and is skipped when ``self.dropout`` is falsy
    # (e.g. None).
    if self.dropout:
        features = self.dropout(features)
    logits = self._classifier(features)

    output = {"logits": logits,
              "probs": torch.nn.functional.softmax(logits, dim=-1)}
    if label is None:
        return output

    gold = label.squeeze(-1)
    output["loss"] = self._loss(logits, gold)
    self._accuracy(logits, gold)
    return output
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict the answer span for ``question`` inside ``passage``.

    The question and the passage share the text field embedder, embedding
    projection, highway layer, encoding projection and phrase layer. A
    passage-question similarity matrix gives passage-to-question attention.
    Multiplying it with question-to-passage attention gives
    passage-to-passage attention (attention over attention). The merged
    attention features are projected and passed three times through the same
    ``self._modeling_layer``. The span start is predicted from the first and
    second modeling outputs, the span end from the first and third.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer within the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
        ``span_end`` must then be given as well; it is used without a separate check.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer within the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question tokens, passage tokens, original passage
        text, and token offsets into the passage for each instance in the batch.  The length
        of this list should be the batch size, and each dictionary should have the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
        An optional ``answer_texts`` list, when non-empty, is scored against the predicted
        span string with ``self._metrics``.

    Returns
    -------
    An output dictionary consisting of:

    passage_question_attention : torch.FloatTensor
        Shape ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position. Padded positions are set to -1e32.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive). Padded positions are set to -1e32.
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised: the sum of the start and end NLL losses.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    question_tokens, passage_tokens : List
        Copied from ``metadata`` when it is provided.
    """
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()

    # Embedding, projection and highway layers are shared by both inputs.
    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
    embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

    batch_size = embedded_question.size(0)

    projected_embedded_question = self._encoding_proj_layer(embedded_question)
    projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

    # The same phrase layer encodes both sequences.
    encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
    encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = masked_softmax(
            passage_question_similarity,
            question_mask,
            memory_efficient=True)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)
    # Attention over attention: passage -> question -> passage.
    # Shape: (batch_size, passage_length, passage_length)
    attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    merged_passage_attention_vectors = self._dropout(
            torch.cat([encoded_passage, passage_question_vectors,
                       encoded_passage * passage_question_vectors,
                       encoded_passage * passage_passage_vectors],
                      dim=-1)
    )

    # After the loop the list holds [projection, pass 1, pass 2, pass 3]. The
    # same modeling layer (same parameters) is applied on each pass.
    modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

    for _ in range(3):
        modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
        modeled_passage_list.append(modeled_passage)

    # Start uses passes 1 and 2 ([-3], [-2]).
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

    # End uses passes 1 and 3 ([-3], [-1]).
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    # Push padded positions to a huge negative value so that neither the
    # softmax nor get_best_span selects them.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

    # Shape: (batch_size, passage_length)
    span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
    }

    # Compute the loss for training: the sum of the start and end NLL losses.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map token indices back to character offsets: start of the first
            # token up to the end of the last token.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
) -> Dict[str, torch.Tensor]:  # add matrix from datareader
    # pylint: disable=arguments-differ
    """
    Joint coreference resolution, relation extraction and NER over a document.

    Candidate spans are embedded once. They are pruned document-wide for
    coreference, and separately per sentence (through ``doc_span_offsets``)
    for relation extraction. The NER scorer runs on the per-sentence span
    embeddings. Each task contributes a weighted loss only when its gold
    inputs are given.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and
        end indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document. Padding spans carry negative indices.
    metadata : ``List[Dict[str, Any]]``, required.
        One dictionary per document. This method reads ``doc_tokens`` (a list of
        sentences, used for the longest sentence length) and ``original_text`` (copied to
        the ``document`` output). The list is also passed to the mention-recall and CoNLL
        coreference metrics.
    doc_span_offsets : ``torch.IntTensor``.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1) of indices
        into ``spans``. Negative entries are padding.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids of each span,
        or -1 for those which do not appear in any clusters. Enables the coreference loss.
    doc_truth_spans : ``torch.IntTensor``.
        A tensor of shape (batch_size, max_sentences, max_truth_spans, 1). Not read in this
        method.
    doc_spans_in_truth : ``torch.IntTensor``.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1). It maps
        each sentence span to its row in ``doc_relation_labels`` (used when
        ``doc_relations`` is given).
    doc_relation_labels : ``torch.Tensor``.
        A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans) of
        gold relation labels between truth spans.
    truth_spans : ``List[Set[Tuple[int, int]]]``, optional
        Passed to the relation mention-recall metric.
    doc_relations : optional
        Enables the relation-extraction loss and metrics. It is passed to
        ``self._compute_relex_metrics``; its structure is not visible here.
    doc_ner_labels : ``torch.IntTensor``, optional
        Presumably per-span NER labels aligned with ``doc_span_offsets``. TODO confirm
        shape. Enables the NER loss.

    Returns
    -------
    An output dictionary consisting of:

    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    document : ``List``
        The ``original_text`` of every metadata entry.
    relex_scores, top_relex_spans : ``torch.Tensor``
        Pairwise relation scores and the spans kept by the relation pruner.
    coref_loss, relex_loss, ner_scores, ner_loss : optional
        Present when the matching gold inputs are given.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised: the weighted sum of the enabled task losses.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(
        self._text_field_embedder(text))

    batch_size = len(spans)
    document_length = text_embeddings.size(1)
    # Longest sentence, in tokens, across the whole batch.
    max_sentence_length = max(
        len(sentence) for document in metadata
        for sentence in document['doc_tokens'])
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans). The squeeze(-1) only has an effect
    # when num_spans == 1.
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0 (relu clamps the -1 padding to 0). This is only
    # relevant in edge cases where the number of spans we consider after the
    # pruning stage is >= the total number of spans, because in this case,
    # it is possible we might consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(
        text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # TODO features dropout
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores. Coreference keeps a budget proportional
    # to document length; relation extraction keeps one proportional to the
    # longest sentence.
    num_spans_to_keep = int(
        math.floor(self._spans_per_word * document_length))
    num_relex_spans_to_keep = int(
        math.floor(self._relex_spans_per_word * max_sentence_length))

    # Shapes:
    # (batch_size, num_spans_to_keep, span_dim),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep, 1)
    (top_span_embeddings, top_span_mask, top_span_indices,
     top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents. Note that this is independent
    # of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents,
                                         util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = dict()

    output_dict["top_spans"] = top_spans
    output_dict["antecedent_indices"] = valid_antecedent_indices
    output_dict["predicted_antecedents"] = predicted_antecedents

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]

    # Starts as the Python int 0; each enabled task branch below adds a
    # weighted loss tensor to it.
    loss = 0

    # Negative offsets mark padding.
    # Shape: (batch_size, max_sentences, max_spans)
    doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
    # Padding offsets are clamped to 0 before the lookup; the mask above
    # keeps them out of later computations.
    # Shape: (batch_size, max_sentences, num_spans, span_dim)
    doc_span_embeddings = util.batched_index_select(
        span_embeddings,
        doc_span_offsets.squeeze(-1).long().clamp(min=0))

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    pruned = self._relex_mention_pruner(
        doc_span_embeddings,
        doc_span_mask,
        num_items_to_keep=num_relex_spans_to_keep,
        pass_through=['num_items_to_keep'])
    (top_relex_span_embeddings, top_relex_span_mask,
     top_relex_span_indices, top_relex_span_mention_scores) = pruned

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

    # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)
    # TODO do we need for a mask?
    doc_spans = util.batched_index_select(
        spans,
        doc_span_offsets.clamp(0).squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
    top_relex_spans = nd_batched_index_select(doc_spans,
                                              top_relex_span_indices)

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
    (relex_span_pair_embeddings,
     relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
         top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
    relex_scores = self._compute_relex_scores(
        relex_span_pair_embeddings, top_relex_span_mention_scores)
    output_dict['relex_scores'] = relex_scores
    output_dict['top_relex_spans'] = top_relex_spans

    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices)

        antecedent_labels_ = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask)
        # log() of the gold indicator leaves a 0 for valid antecedents and -inf
        # for the rest, so logsumexp only sums over consistent antecedents.
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs)
        negative_marginal_log_likelihood *= top_span_mask.squeeze(-1).float()
        negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        coref_loss = negative_marginal_log_likelihood
        output_dict['coref_loss'] = coref_loss
        loss += self._loss_coref_weight * coref_loss

    if doc_relations is not None:
        # The adjacency matrix for relation extraction is very sparse.
        # As it is not just sparse, but row/column sparse (only few
        # rows and columns are non-zero and in that case these rows/columns
        # are not sparse), we implemented our own matrix for the case.
        # Here we have indices of truth spans and mapping, using which
        # we map prediction matrix on truth matrix.
        # TODO Add teacher forcing support.

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
        relative_indices = top_relex_span_indices
        # Row of each kept span in the gold (truth-span) relation matrix.
        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
        compressed_indices = nd_batched_padded_index_select(
            doc_spans_in_truth, relative_indices)

        # Select the gold rows for the kept spans, then transpose so that the
        # same index selection can pick the columns.
        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
        gold_pruned_rows = nd_batched_padded_index_select(
            doc_relation_labels,
            compressed_indices.squeeze(-1),
            padding_value=0)
        gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3, 2).contiguous()

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
        gold_pruned_matrices = nd_batched_padded_index_select(
            gold_pruned_rows,
            compressed_indices.squeeze(-1),
            padding_value=0)  # pad with epsilon
        gold_pruned_matrices = gold_pruned_matrices.permute(
            0, 1, 3, 2).contiguous()

        # TODO log_mask relex score before passing
        relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                  gold_pruned_matrices,
                                                  relex_span_pair_mask)
        output_dict['relex_loss'] = relex_loss

        self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                   truth_spans)
        self._compute_relex_metrics(output_dict, doc_relations)

        loss += self._loss_relex_weight * relex_loss

    if doc_ner_labels is not None:
        # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
        ner_scores = self._ner_scorer(doc_span_embeddings)
        output_dict['ner_scores'] = ner_scores

        ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                doc_span_mask)
        output_dict['ner_loss'] = ner_loss
        loss += self._loss_ner_weight * ner_loss

    if not isinstance(loss, int):  # Only report a loss once some branch above added to it
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
        Its last dimension is squeezed before use, so a shape of (batch_size, 1) is expected.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
        Same shape as ``span_start``.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, one dictionary per instance in the batch.  Each dictionary must have the
        keys ``question_tokens``, ``passage_tokens``, ``original_passage`` and
        ``token_offsets``; an optional ``answer_texts`` list, when non-empty, is used to update
        the SQuAD EM/F1 metric.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Passage-to-question attention of shape ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position (padding positions set to -1e7).
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If metadata was provided, the string from the original passage that the model thinks
        is the best answer to the question.
    question_tokens, passage_tokens : List
        Copied from metadata when it is provided.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    # ``max(dim=-1)`` already removes the reduced dimension.  An extra ``squeeze(-1)`` would
    # also collapse the passage dimension whenever passage_length == 1.
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                               passage_length,
                                                                               modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Push padding positions far down so the span search never selects them.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
            }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        # span_start / span_end carry a trailing singleton dim (squeezed above for nll_loss).
        # Concatenating along it yields gold spans of shape (batch_size, 2), matching
        # ``best_span``.  ``torch.stack`` would produce (batch_size, 1, 2) instead.
        self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map inclusive token indices back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        metadata: Dict,
        tokens: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score every token of a BERT-encoded question/passage pair as answer start and end.

    Parameters
    ----------
    metadata : Dict
        Accepted for interface compatibility; not read by this method.
    tokens : Dict[str, torch.LongTensor]
        From a ``TextField`` (that has a bert-pretrained token indexer).  The word-piece ids are
        read from ``tokens[self._index]`` and the segment ids from
        ``tokens[f"{self._index}-type-ids"]`` (0 for the question, 1 for the context).
    span_start : torch.IntTensor, optional (default = None)
        A tensor of shape (batch_size, 1) which contains the start_position of the answer
        in the passage, or 0 if impossible.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : torch.IntTensor, optional (default = None)
        A tensor of shape (batch_size, 1) which contains the end_position of the answer
        in the passage, or 0 if impossible.  This is an `inclusive` token index.
        Must be given together with ``span_start``.

    Returns
    -------
    An output dictionary consisting of:
    start_logits : torch.FloatTensor
        Shape ``(batch_size, seq_len)``; unnormalised start scores, with positions that cannot
        be an answer boundary replaced by -1e7.
    end_logits : torch.FloatTensor
        Shape ``(batch_size, seq_len)``; the same for the answer end.
    start_probs : torch.FloatTensor
        ``softmax(start_logits)`` over the sequence dimension.
    end_probs : torch.FloatTensor
        ``softmax(end_logits)`` over the sequence dimension.
    loss : torch.FloatTensor, optional
        Mean of the start and end losses; only present when ``span_start`` is given.
    _span_start_accuracy, _span_end_accuracy : optional
        The accuracy metric objects themselves, only present when ``span_start`` is given.
    """
    input_ids = tokens[self._index]
    token_type_ids = tokens[f"{self._index}-type-ids"]
    # Word-piece id 0 is treated as padding.
    input_mask = (input_ids != 0).long()

    # Shape: (batch_size, seq_len, hidden_dim); the pooled output is not used.
    bert_embeddings, _ = self.bert_model(input_ids, token_type_ids, attention_mask=input_mask)
    bert_embeddings = self.dropout(bert_embeddings)

    # Shape: (batch_size, seq_len).  Only the trailing size-1 dim is squeezed; a bare
    # ``squeeze()`` would also drop the batch dim when batch_size == 1.
    start_logits = self.start_linear(bert_embeddings).squeeze(-1)
    end_logits = self.end_linear(bert_embeddings).squeeze(-1)

    # 1.0 where an answer boundary may be predicted: non-padding context tokens (segment 1),
    # plus position 0, which the docstring designates as the target for impossible answers.
    answer_positions = (token_type_ids == 1).float() * input_mask.float()
    answer_positions[:, 0] = 1.0
    # Disallowed scores are *replaced* with a large negative value.  Multiplying them by a
    # large negative factor instead would flip the sign of negative scores and make question
    # tokens look like the best answers.
    start_logits = start_logits * answer_positions + (1.0 - answer_positions) * -1e7
    end_logits = end_logits * answer_positions + (1.0 - answer_positions) * -1e7

    start_probs = torch.nn.functional.softmax(start_logits, dim=-1)
    end_probs = torch.nn.functional.softmax(end_logits, dim=-1)

    output_dict = {
        "start_logits": start_logits,
        "end_logits": end_logits,
        "start_probs": start_probs,
        "end_probs": end_probs,
    }

    if span_start is not None:
        # Shape: (batch_size,).  view(-1) flattens (batch_size, 1) without dropping the
        # batch dim at batch_size == 1.
        gold_start = span_start.view(-1).long()
        gold_end = span_end.view(-1).long()
        # The loss is applied to logits, not to softmax outputs (a cross-entropy loss would
        # otherwise apply softmax twice).
        # NOTE(review): assumes self.loss_function is a logits-based loss such as
        # torch.nn.CrossEntropyLoss — confirm in __init__.
        start_loss = self.loss_function(start_logits, gold_start)
        end_loss = self.loss_function(end_logits, gold_end)
        self._span_start_accuracy(start_logits, gold_start)
        self._span_end_accuracy(end_logits, gold_end)

        output_dict["loss"] = (start_loss + end_loss) / 2
        output_dict["_span_start_accuracy"] = self._span_start_accuracy
        output_dict["_span_end_accuracy"] = self._span_end_accuracy

    # TODO: decode a best span (e.g. allennlp's get_best_span) and report SQuAD EM/F1 against
    # metadata[i]['answer_texts'].
    return output_dict
def forward(
        self,  # type: ignore
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict the entailment label of a premise/hypothesis pair (attend, compare, aggregate).

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
        entailment label.
    h2p_attention : torch.FloatTensor
        Hypothesis-to-premise attention weights, ``(batch_size, hypothesis_length, premise_length)``.
    p2h_attention : torch.FloatTensor
        Premise-to-hypothesis attention weights, ``(batch_size, premise_length, hypothesis_length)``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised (only when ``label`` is given).
    """
    def _pool_comparison(side, aligned_other, side_mask):
        # Compare each token with its aligned summary, zero out padding, sum over tokens.
        # Result shape: (batch_size, compare_dim)
        comparison = self._compare_feedforward(torch.cat([side, aligned_other], dim=-1))
        return (comparison * side_mask.unsqueeze(-1)).sum(dim=1)

    premise_vectors = self._text_field_embedder(premise)
    hypothesis_vectors = self._text_field_embedder(hypothesis)
    premise_pad_mask = get_text_field_mask(premise).float()
    hypothesis_pad_mask = get_text_field_mask(hypothesis).float()

    # Optional contextual encoders; truthiness test as the model configures them.
    if self._premise_encoder:
        premise_vectors = self._premise_encoder(premise_vectors, premise_pad_mask)
    if self._hypothesis_encoder:
        hypothesis_vectors = self._hypothesis_encoder(hypothesis_vectors, hypothesis_pad_mask)

    premise_keys = self._attend_feedforward(premise_vectors)
    hypothesis_keys = self._attend_feedforward(hypothesis_vectors)

    # Shape: (batch_size, premise_length, hypothesis_length)
    similarities = self._matrix_attention(premise_keys, hypothesis_keys)

    premise_to_hypothesis = last_dim_softmax(similarities, hypothesis_pad_mask)
    # Shape: (batch_size, premise_length, embedding_dim)
    aligned_hypothesis = weighted_sum(hypothesis_vectors, premise_to_hypothesis)

    flipped = similarities.transpose(1, 2).contiguous()
    hypothesis_to_premise = last_dim_softmax(flipped, premise_pad_mask)
    # Shape: (batch_size, hypothesis_length, embedding_dim)
    aligned_premise = weighted_sum(premise_vectors, hypothesis_to_premise)

    premise_summary = _pool_comparison(premise_vectors, aligned_hypothesis, premise_pad_mask)
    hypothesis_summary = _pool_comparison(hypothesis_vectors, aligned_premise, hypothesis_pad_mask)

    label_logits = self._aggregate_feedforward(
        torch.cat([premise_summary, hypothesis_summary], dim=-1))
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {
        "label_logits": label_logits,
        "label_probs": label_probs,
        "h2p_attention": hypothesis_to_premise,
        "p2h_attention": premise_to_hypothesis,
    }

    if label is not None:
        output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label.squeeze(-1))

    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        sentence_spans: torch.IntTensor = None,
        sent_labels: torch.IntTensor = None,
        q_type: torch.IntTensor = None,
        sp_mask: torch.IntTensor = None,
        coref_mask: torch.FloatTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict an answer span and an answer type for a question over a passage.

    The answer type is the argmax of ``self.linear_type``.  Type 1 produces the answer string
    ``'yes'``, type 2 produces ``'no'``, and any other type takes the answer from ``best_span``.

    Parameters
    ----------
    question, passage : Dict[str, torch.LongTensor]
        Token tensors from ``TextField``s.
    span_start, span_end : torch.IntTensor, optional
        Gold inclusive start/end token indices; their last dimension is squeezed before use.
        When ``span_start`` is given, a loss is computed (``span_end`` and ``q_type`` are then
        also required).
    q_type : torch.IntTensor, optional
        Gold answer-type class indices used for the type loss.
    sentence_spans, sent_labels, sp_mask, coref_mask : optional
        Accepted for interface compatibility; not read by this method.
    metadata : List[Dict[str, Any]], optional
        One dictionary per instance.  Keys read: ``question_tokens``, ``passage_tokens``,
        ``token_spans_sp``, ``token_spans_sent``, ``sent_labels``, ``coref_clusters``, ``_id``,
        ``original_passage``, ``token_offsets`` and optionally ``answer_texts``.

    Returns
    -------
    Dict with ``span_start_logits``, ``span_end_logits``, ``best_span`` and ``qc_score``.
    ``loss`` is added when the loss can be computed.  When metadata is given, the dict also
    holds ``best_span_str``, ``answer_texts`` and the per-instance metadata lists.
    """
    embedded_question = self._text_field_embedder(question)
    embedded_passage = self._text_field_embedder(passage)
    ques_mask = util.get_text_field_mask(question).float()
    context_mask = util.get_text_field_mask(passage).float()

    ques_output = self._dropout(self._phrase_layer(embedded_question, ques_mask))
    context_output = self._dropout(self._phrase_layer(embedded_passage, context_mask))

    modeled_passage, qc_score = self.qc_att(context_output, ques_output, ques_mask)
    modeled_passage = self.linear_1(modeled_passage)
    modeled_passage = self._modeling_layer(modeled_passage, context_mask)

    batch_size = modeled_passage.size(0)

    # Padding positions get a -1e30 offset so they are never chosen as span boundaries.
    output_start = self._span_start_encoder(modeled_passage, context_mask)
    span_start_logits = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask)

    output_end = torch.cat([modeled_passage, output_start], dim=2)
    output_end = self._span_end_encoder(output_end, context_mask)
    span_end_logits = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask)

    # Answer type from a max-pool over the sequence of all three representations.
    output_type = torch.cat([modeled_passage, output_end, output_start], dim=2)
    output_type = torch.max(output_type, 1)[0]
    predict_type = self.linear_type(output_type)
    type_predicts = torch.argmax(predict_type, 1)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_end_logits": span_end_logits,
        "best_span": best_span,
        "qc_score": qc_score
    }

    # Compute the loss for training.
    if span_start is not None:
        try:
            start_loss = nll_loss(
                util.masked_log_softmax(span_start_logits, None),
                span_start.squeeze(-1))
            end_loss = nll_loss(
                util.masked_log_softmax(span_end_logits, None),
                span_end.squeeze(-1))
            type_loss = nll_loss(
                util.masked_log_softmax(predict_type, None), q_type)
        except RuntimeError:
            # Best effort: the failing batch is reported (with the actual error and traceback)
            # and no "loss" entry is produced for it.
            import logging
            logging.getLogger(__name__).exception(
                "Loss computation failed; span_start_logits shape %s, metadata: %s",
                tuple(span_start_logits.shape), metadata)
        else:
            loss = start_loss + end_loss + type_loss
            self._loss_trackers['loss'](loss)
            self._loss_trackers['start_loss'](start_loss)
            self._loss_trackers['end_loss'](end_loss)
            self._loss_trackers['type_loss'](type_loss)
            output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        token_spans_sp = []
        token_spans_sent = []
        sent_labels_list = []
        coref_clusters = []
        ids = []

        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            token_spans_sp.append(metadata[i]['token_spans_sp'])
            token_spans_sent.append(metadata[i]['token_spans_sent'])
            sent_labels_list.append(metadata[i]['sent_labels'])
            coref_clusters.append(metadata[i]['coref_clusters'])
            ids.append(metadata[i]['_id'])

            if type_predicts[i] == 1:
                best_span_string = 'yes'
            elif type_predicts[i] == 2:
                best_span_string = 'no'
            else:
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Map inclusive token indices back to character offsets in the raw passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]

            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            output_dict['answer_texts'].append(answer_texts)

            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)

        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
        output_dict['token_spans_sp'] = token_spans_sp
        output_dict['token_spans_sent'] = token_spans_sent
        output_dict['sent_labels'] = sent_labels_list
        output_dict['coref_clusters'] = coref_clusters
        output_dict['_id'] = ids

    return output_dict
def forward(  # type: ignore
    self,
    question: Dict[str, torch.LongTensor],
    passage: Dict[str, torch.LongTensor],
    span_start: torch.IntTensor = None,
    span_end: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question tokens, passage tokens, original passage
        text, and token offsets into the passage for each instance in the batch.  The length
        of this list should be the batch size, and each dictionary should have the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
        An optional ``answer_texts`` list, when non-empty, updates the SQuAD EM/F1 metric.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Passage-to-question attention of shape ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position (padding positions set to -1e7).
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    question_tokens, passage_tokens, token_offsets : List
        Copied from metadata when it is provided.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question)
    passage_mask = util.get_text_field_mask(passage)
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(
        passage_question_similarity, question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size, passage_length)
    # ``max(dim=-1)`` already removes the reduced dimension.  An extra ``squeeze(-1)`` would
    # also collapse the passage dimension whenever passage_length == 1.
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
        batch_size, passage_length, encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat(
        [
            encoded_passage,
            passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector,
        ],
        dim=-1,
    )

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(
        batch_size, passage_length, modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat(
        [
            final_merged_passage,
            modeled_passage,
            tiled_start_representation,
            modeled_passage * tiled_start_representation,
        ],
        dim=-1,
    )
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(
        self._span_end_encoder(span_end_representation, passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Push padding positions far down so the span search never selects them.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(
            util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(
            util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        # Gold spans as (batch_size, 2), matching best_span.
        self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict["best_span_str"] = []
        question_tokens = []
        passage_tokens = []
        token_offsets = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]["question_tokens"])
            passage_tokens.append(metadata[i]["passage_tokens"])
            token_offsets.append(metadata[i]["token_offsets"])
            passage_str = metadata[i]["original_passage"]
            offsets = metadata[i]["token_offsets"]
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map inclusive token indices back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict["best_span_str"].append(best_span_string)
            answer_texts = metadata[i].get("answer_texts", [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict["question_tokens"] = question_tokens
        output_dict["passage_tokens"] = passage_tokens
        output_dict["token_offsets"] = token_offsets
    return output_dict
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Classify the entailment relation between a premise and a hypothesis.

    Each side is soft-aligned to the other.  Every token is then compared with its aligned
    summary, the comparisons are summed over non-padding tokens, and the two pooled vectors
    feed the final classifier.

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
        entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised (only when ``label`` is given).
    """
    p_tokens = self._text_field_embedder(premise)
    h_tokens = self._text_field_embedder(hypothesis)
    p_mask = get_text_field_mask(premise).float()
    h_mask = get_text_field_mask(hypothesis).float()

    if self._premise_encoder:
        p_tokens = self._premise_encoder(p_tokens, p_mask)
    if self._hypothesis_encoder:
        h_tokens = self._hypothesis_encoder(h_tokens, h_mask)

    # Attend: Shape (batch_size, premise_length, hypothesis_length)
    p_projected = self._attend_feedforward(p_tokens)
    h_projected = self._attend_feedforward(h_tokens)
    scores = self._matrix_attention(p_projected, h_projected)

    # Soft alignment in both directions.
    h_aligned = weighted_sum(h_tokens, last_dim_softmax(scores, h_mask))
    p_aligned = weighted_sum(p_tokens,
                             last_dim_softmax(scores.transpose(1, 2).contiguous(), p_mask))

    # Compare and pool, premise first then hypothesis; each result is (batch_size, compare_dim).
    pooled = []
    for own, aligned, mask in ((p_tokens, h_aligned, p_mask), (h_tokens, p_aligned, h_mask)):
        compared = self._compare_feedforward(torch.cat([own, aligned], dim=-1))
        pooled.append((compared * mask.unsqueeze(-1)).sum(dim=1))

    label_logits = self._aggregate_feedforward(torch.cat(pooled, dim=-1))
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is None:
        return output_dict

    output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
    self._accuracy(label_logits, label.squeeze(-1))
    return output_dict
def forward(self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata=None) -> Dict[str, torch.Tensor]:
    """
    Predict an answer span in ``passage`` for ``question`` (R-Net-style gated attention).

    The encoded passage is first made question-aware by a gated-attention recurrence,
    then self-matched by a second gated-attention recurrence.  A two-step pointer
    network, initialised from an attention-pooled question vector ``r_Q``, then scores
    span starts and span ends.

    Parameters
    ----------
    question, passage : Dict[str, torch.LongTensor]
        Token-indexer outputs for the question and the passage.
    span_start, span_end : torch.IntTensor, optional
        Gold span boundaries.  When ``span_start`` is given, ``span_end`` must be given
        too; a loss is computed and the span metrics are updated.
    metadata : list of dict, optional
        One dict per instance with ``question_tokens``, ``passage_tokens``,
        ``original_passage``, ``token_offsets`` and optionally ``answer_texts``.

    Returns
    -------
    Dict with:
    ``span_start_idx`` / ``span_end_idx``
        Unconstrained argmax of the start / end distributions.
    ``loss``
        Present when gold spans are given.
    ``best_span_str``, ``question_tokens``, ``passage_tokens``
        Present when ``metadata`` is given.
    """
    # Shape: B x T x E
    embedded_question = self._embedder(question)
    embedded_passage = self._embedder(passage)
    batch_size = embedded_question.size(0)
    total_passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question)
    passage_mask = util.get_text_field_mask(passage)

    # Shape: B x T x 2H
    encoded_question = self._dropout(self._question_encoder(embedded_question, question_mask))
    encoded_passage = self._dropout(self._passage_encoder(embedded_passage, passage_mask))
    passage_mask = passage_mask.float()
    question_mask = question_mask.float()
    encoding_dim = encoded_question.size(-1)

    # Zero initial GRU state, created on the same device (and with the same dtype) as the
    # encoder output, so no explicit CUDA/CPU branch is needed.
    # Shape: B x 2H
    gru_hidden = Variable(encoded_passage.data.new(batch_size, encoding_dim).zero_())

    def gated_attention_recurrence(queries, memory, memory_mask, attention, gate_layer,
                                   hidden, num_steps):
        # One gated-attention RNN pass: at every step attend from queries[:, t] over
        # ``memory``, gate [query_t; attended] and feed it to the GRU cell.
        outputs = []
        for timestep in range(num_steps):
            query_t = queries[:, timestep, :]
            # Shape: B x T_mem = attention(B x 2H, B x T_mem x 2H)
            attn_weights = attention(query_t, memory, memory_mask)
            # Shape: B x 2H
            attended = util.weighted_sum(memory, attn_weights)
            # Shape: B x 4H
            combined = torch.cat([query_t, attended], dim=-1)
            gate = torch.sigmoid(gate_layer(combined))
            hidden = self._dropout(self._gru_cell(gate * combined, hidden))
            outputs.append(hidden)
        # Shape: B x T x 2H, plus the last hidden state.
        return torch.stack(outputs, dim=1), hidden

    # Question-aware passage representation v_P.
    # NOTE(review): both recurrences share ``self._gru_cell``, and the second one starts
    # from the final state of the first; R-Net normally uses separate RNNs — confirm
    # this is intended.
    question_awared_passage, gru_hidden = gated_attention_recurrence(
        encoded_passage, encoded_question, question_mask,
        self._question_attention_for_passage, self._gate, gru_hidden, total_passage_length)
    # Self-matched passage representation.
    self_attended_passage, _ = gated_attention_recurrence(
        question_awared_passage, question_awared_passage, passage_mask,
        self._passage_self_attention, self._self_gate, gru_hidden, total_passage_length)

    # Question vector r_Q: attention pooling over the question, driven by the learned
    # vector v_r_Q tiled over the batch.
    v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
    attn_weights = self._question_attention_for_question(
        v_r_Q_tiled, encoded_question, question_mask)
    # Shape: B x 2H
    r_Q = util.weighted_sum(encoded_question, attn_weights)

    # Pointer network, step 1: span-start distribution.
    # Shape: B x T = attention(B x 2H, B x T x 2H)
    span_start_logits = self._passage_attention_for_answer(
        r_Q, self_attended_passage, passage_mask)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_mask)

    # Pointer network, step 2: update the answer state with the start-attended passage
    # summary and score span ends.
    # Shape: B x 2H
    c_t = util.weighted_sum(self_attended_passage, span_start_probs)
    h_1 = self._dropout(self._answer_net(c_t, r_Q))
    span_end_logits = self._passage_attention_for_answer(
        h_1, self_attended_passage, passage_mask)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_mask)

    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {}
    if span_start is not None:
        output_dict['loss'] = (F.nll_loss(span_start_log_probs, span_start.squeeze(-1)) +
                               F.nll_loss(span_end_log_probs, span_end.squeeze(-1)))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))

    # Unconstrained per-position argmax (may differ from ``best_span``).
    _, max_start = torch.max(span_start_probs, dim=1)
    _, max_end = torch.max(span_end_probs, dim=1)
    output_dict['span_start_idx'] = max_start
    output_dict['span_end_idx'] = max_end

    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].data.cpu().numpy())
            # Map inclusive token indices back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        get_sample_level_information=True) -> Dict[str, torch.Tensor]:
    """
    Run every submodel in turn and merge their span predictions into one answer.

    WE LOAD THE MODELS ONE INTO GPU ONE AT A TIME !!!
    Each submodel is moved to ``submodel.cf_a.device`` only for its own forward pass.
    It is moved back to the CPU afterwards, even if that pass raises, so only one
    submodel occupies the accelerator at a time.

    Parameters
    ----------
    question, passage : Dict[str, torch.LongTensor]
        Token-indexer outputs, forwarded unchanged to every submodel.
    span_start, span_end : torch.IntTensor, optional
        Gold span boundaries.  They are forwarded to every submodel and used for the
        per-sample ensemble losses.
    metadata : List[Dict[str, Any]], optional
        Per-instance dicts with ``original_passage``, ``token_offsets`` and optionally
        ``answer_texts``; needed for ``best_span_str`` and the SQuAD metrics.
    get_sample_level_information : bool
        When true, also return:
        - per-sample EM/F1 (for instances with ``answer_texts``);
        - per-sample ensemble negative log-likelihoods, when ``span_start`` and
          ``span_end`` are given.

    Returns
    -------
    Dict with:
    ``best_span``
        Result of ``merge_span_probs``.
    ``best_span_str``, ``models_output``
        Always present; ``models_output`` holds the raw submodel outputs.
    ``em_samples``, ``f1_samples``, ``span_start_sample_loss``, ``span_end_sample_loss``
        Optional, as described above.
    """
    subresults = []
    for submodel in self.submodels:
        submodel.to(device=submodel.cf_a.device)
        try:
            subres = submodel(question, passage, span_start, span_end, metadata,
                              get_sample_level_information)
        finally:
            # Evict the submodel even if its forward pass failed, so the next one fits.
            submodel.to(device=torch.device("cpu"))
        subresults.append(subres)

    batch_size = len(subresults[0]["best_span"])
    best_span = merge_span_probs(subresults)

    output = {
        "best_span": best_span,
        "best_span_str": [],
        "models_output": subresults
    }
    if get_sample_level_information:
        output["em_samples"] = []
        output["f1_samples"] = []

    if metadata is not None:
        for index in range(batch_size):
            passage_str = metadata[index]['original_passage']
            offsets = metadata[index]['token_offsets']
            predicted_span = tuple(best_span[index].detach().cpu().numpy())
            # Map inclusive token indices back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output["best_span_str"].append(best_span_string)

            answer_texts = metadata[index].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
                if get_sample_level_information:
                    em_sample, f1_sample = bidut.get_em_f1_metrics(
                        best_span_string, answer_texts)
                    output["em_samples"].append(em_sample)
                    output["f1_samples"].append(f1_sample)

    # Per-sample losses need gold spans; without them (plain inference) they are skipped
    # instead of crashing on ``span_start.squeeze``.
    if get_sample_level_information and span_start is not None:

        def per_sample_nll(probs, gold):
            # ``probs`` are probabilities, not log-probabilities, so take the log
            # explicitly.  ``nll_loss`` on raw probabilities would return -p, not -log p.
            # The clamp keeps a gold position with zero probability finite.
            gold_index = gold.squeeze(-1).long().view(-1, 1)
            gold_probs = probs.gather(1, gold_index).clamp(min=1e-45)
            return (-torch.log(gold_probs)).view(-1).detach().cpu().tolist()

        # The ensemble distribution is the mean of the submodels' span probabilities.
        # It does not depend on the sample index, so it is computed once.
        num_models = len(subresults)
        span_start_probs = sum(
            subresult['span_start_probs'] for subresult in subresults) / num_models
        span_end_probs = sum(
            subresult['span_end_probs'] for subresult in subresults) / num_models

        # Add information about the individual samples for future analysis
        output["span_start_sample_loss"] = per_sample_nll(span_start_probs, span_start)
        output["span_end_sample_loss"] = per_sample_nll(span_end_probs, span_end)
    return output
def forward(
        self,  # type: ignore
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Classify entailment from pooled premise/hypothesis sentence vectors.

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
        entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; only present when ``label`` is given.
    """
    embedded_premise = self._text_field_embedder(premise)
    embedded_premise = self._embeddings_dropout(embedded_premise)
    embedded_hypothesis = self._text_field_embedder(hypothesis)
    embedded_hypothesis = self._embeddings_dropout(embedded_hypothesis)

    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    # NOTE(review): the pooling below only happens when an encoder is configured.
    # Presumably both encoders are always set, otherwise the inputs to the
    # concatenation stay per-token — confirm.
    if self._premise_encoder:
        embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        embedded_premise = seq2vec_seq_aggregate(
            embedded_premise, premise_mask, self._premise_aggregate,
            self._premise_encoder.is_bidirectional(), 1)

    if self._hypothesis_encoder:
        embedded_hypothesis = self._hypothesis_encoder(
            embedded_hypothesis, hypothesis_mask)
        # Ask the encoder that actually produced this sequence whether it is bidirectional
        # (previously the premise encoder was queried here by mistake).
        embedded_hypothesis = seq2vec_seq_aggregate(
            embedded_hypothesis, hypothesis_mask, self._hypothesis_aggregate,
            self._hypothesis_encoder.is_bidirectional(), 1)

    # Standard sentence-pair matching features [u; v; |u - v|; u * v].  The last feature
    # previously multiplied the hypothesis by itself instead of by the premise.
    aggregate_input = torch.cat([
        embedded_premise,
        embedded_hypothesis,
        torch.abs(embedded_hypothesis - embedded_premise),
        embedded_hypothesis * embedded_premise
    ], dim=-1)
    label_logits = self._aggregate_feedforward(aggregate_input)
    # Explicit dim: an implicit softmax dim is deprecated; -1 is the label dimension.
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {
        "label_logits": label_logits,
        "label_probs": label_probs
    }

    if label is not None:
        labels = label.long().view(-1)
        loss = self._loss(label_logits, labels)
        self._accuracy(label_logits, label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    BiDAF forward pass that scores several gold answer spans per question.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the
        answer(s) to the question, and predicts the beginning and ending positions of
        the answers within the passage.
    span_start : ``torch.IntTensor``, optional
        Gold start positions, as `inclusive` token indices.
        - After ``squeeze(-1)`` it is indexed as ``(batch_size, max_answer_count)``.
        - Answer slots that hold ``-1`` are ignored by the loss; this is presumably
          the padding value — confirm with the dataset reader.
        - If given, a loss is computed and included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        Gold end positions (`inclusive` token indices), laid out like ``span_start``.
    metadata : ``List[Dict[str, Any]]``, optional
        One dict per instance with ``question_tokens``, ``passage_tokens``,
        ``original_passage``, ``token_offsets`` and ``answer_texts``.
        ``answer_texts`` is read with ``elem['answer_texts']`` to size the answer mask,
        so it must be present whenever ``metadata`` is given.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Passage-to-question attention, ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position (padding replaced by ``-1e7``).
    span_start_probs : torch.FloatTensor
        The result of ``masked_softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``masked_softmax(span_end_logits)``.
    best_span : torch.IntTensor
        Result of ``self.get_best_span`` given the per-instance answer counts.
        - It is iterated as a sequence of ``(start, end)`` pairs per instance.
        - A start of ``-1`` ends that instance's predictions.
        - Presumably shaped ``(batch_size, max_answer_count, 2)`` — confirm against
          ``get_best_span``.
    loss : torch.FloatTensor, optional
        Sum over answer slots of the start and end negative log-likelihoods.
    best_span_str : List[List[str]]
        If ``metadata`` was provided, the predicted answer strings of each instance.
    question_tokens, passage_tokens : list
        Copied from ``metadata`` when it is provided.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                               passage_length,
                                                                               modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, span_end_encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)

    # Number of gold answers per instance, used to size the best-span search and the
    # span-accuracy mask (1 for real answer slots, 0 for padded ones).
    answer_len = [len(elem['answer_texts']) for elem in metadata] if metadata is not None else []
    if answer_len:
        mask = torch.zeros((batch_size, max(answer_len), 2)).long()
        for index, length in enumerate(answer_len):
            mask[index, :length] = 1
    else:
        mask = None

    best_span = self.get_best_span(span_start_logits, span_end_logits, answer_len)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        span_start = span_start.squeeze(-1)  # batch X max_answer_L
        span_end = span_end.squeeze(-1)  # batch X max_answer_L
        # Loop-invariant: the log-distributions are the same for every answer slot.
        span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_mask)

        # Collected in the original order (start, end, start, end, ...) so that ``sum``
        # accumulates the loss left to right exactly as before.
        step_losses = []
        for step in range(span_start.size(1)):
            gold_start = span_start[:, step]  # batch
            gold_end = span_end[:, step]  # batch
            # Answer slots that hold -1 (presumably padding) are skipped by the loss.
            step_losses.append(nll_loss(span_start_log_probs, gold_start, ignore_index=-1))
            # TODO: padded (-1) answer slots are still fed to the accuracy metrics.
            self._span_start_accuracy(span_start_logits, gold_start)
            step_losses.append(nll_loss(span_end_log_probs, gold_end, ignore_index=-1))
            self._span_end_accuracy(span_end_logits, gold_end)
        loss = sum(step_losses)

        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1), mask)
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            best_span_strings = []
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_spans = tuple(best_span[i].data.cpu().numpy())
            for predicted_span in predicted_spans:
                # A start of -1 marks the end of this instance's predictions.
                if predicted_span[0] == -1:
                    break
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                best_span_strings.append(best_span_string)
            output_dict['best_span_str'].append(best_span_strings)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_strings, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward_ensemble(self,  # type: ignore
                     question: Dict[str, torch.LongTensor],
                     passage: Dict[str, torch.LongTensor],
                     span_start: torch.IntTensor = None,
                     span_end: torch.IntTensor = None,
                     metadata: List[Dict[str, Any]] = None,
                     get_sample_level_information=False,
                     num_samples: int = 10) -> Dict[str, torch.Tensor]:
    """
    Ensemble the posterior-mean prediction with ``num_samples`` stochastic forward passes.

    The first pass runs with ``set_posterior_mean(True)``.  The flag is switched back to
    ``False``, even if that pass raises, before ``num_samples`` further passes are
    drawn.  All ``num_samples + 1`` outputs are merged with ``bidut.merge_span_probs``.

    Parameters
    ----------
    question, passage, span_start, span_end, metadata :
        Forwarded unchanged to ``self.forward``.
    get_sample_level_information : bool
        When true, also return:
        - per-sample EM/F1 (for instances with ``answer_texts``);
        - per-sample negative log-likelihoods of the averaged distribution, when
          ``span_start`` and ``span_end`` are given.
    num_samples : int
        Number of stochastic passes in addition to the posterior-mean pass (default 10).

    Returns
    -------
    Dict with:
    ``best_span``, ``best_span_str``
        The merged prediction and its string (the string needs ``metadata``).
    ``models_output``
        The individual outputs.
    ``em_samples``, ``f1_samples``, ``span_start_sample_loss``, ``span_end_sample_loss``
        Optional, as described above.
    """
    self.set_posterior_mean(True)
    try:
        most_likely_output = self.forward(question, passage, span_start, span_end,
                                          metadata, get_sample_level_information)
    finally:
        self.set_posterior_mean(False)

    subresults = [most_likely_output]
    for _ in range(num_samples):
        subresults.append(self.forward(question, passage, span_start, span_end,
                                       metadata, get_sample_level_information))

    batch_size = len(subresults[0]["best_span"])
    best_span = bidut.merge_span_probs(subresults)

    output = {
        "best_span": best_span,
        "best_span_str": [],
        "models_output": subresults
    }
    if get_sample_level_information:
        output["em_samples"] = []
        output["f1_samples"] = []

    if metadata is not None:
        for index in range(batch_size):
            passage_str = metadata[index]['original_passage']
            offsets = metadata[index]['token_offsets']
            predicted_span = tuple(best_span[index].detach().cpu().numpy())
            # Map inclusive token indices back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output["best_span_str"].append(best_span_string)

            answer_texts = metadata[index].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
                if get_sample_level_information:
                    em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,
                                                                   answer_texts)
                    output["em_samples"].append(em_sample)
                    output["f1_samples"].append(f1_sample)

    # Per-sample losses need gold spans; without them they are skipped instead of
    # crashing on ``span_start.squeeze``.
    if get_sample_level_information and span_start is not None:

        def per_sample_nll(probs, gold):
            # ``probs`` are probabilities, not log-probabilities, so take the log
            # explicitly.  ``nll_loss`` on raw probabilities would return -p, not -log p.
            # The clamp keeps a gold position with zero probability finite.
            gold_index = gold.squeeze(-1).long().view(-1, 1)
            gold_probs = probs.gather(1, gold_index).clamp(min=1e-45)
            return (-torch.log(gold_probs)).view(-1).detach().cpu().tolist()

        # Averaged distribution over all passes; independent of the sample index, so it
        # is computed once.
        num_outputs = len(subresults)
        span_start_probs = sum(
            subresult['span_start_probs'] for subresult in subresults) / num_outputs
        span_end_probs = sum(
            subresult['span_end_probs'] for subresult in subresults) / num_outputs

        # Add information about the individual samples for future analysis
        output["span_start_sample_loss"] = per_sample_nll(span_start_probs, span_start)
        output["span_end_sample_loss"] = per_sample_nll(span_end_probs, span_end)
    return output
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    BiDAF forward pass: bidirectional passage/question attention, a modeling layer, then
    separate start and end predictors over passage positions.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer
        to the question, and predicts the beginning and ending positions of the answer
        within the passage.  The ending position is `exclusive`, so our
        :class:`~allennlp.data.dataset_readers.SquadReader` adds a special ending token
        to the end of the passage, to allow for the last token to be included in the
        answer span.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict: the
        beginning position of the answer within the passage, as an `inclusive` index.
        If this is given, we will compute a loss that gets included in the output
        dictionary (``span_end`` must then be given too).
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict: the
        ending position of the answer within the passage, as an `exclusive` index.
        If this is given, we will compute a loss that gets included in the output
        dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, one dict per instance that is passed to
        ``self._compute_official_metrics`` together with the predicted span.  Presumably
        it holds the question ID, original passage text and token offsets used by the
        official SQuAD evaluation — confirm against that helper.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
        probabilities of the span start position.  Padding positions are replaced by
        ``-1e7``.
    span_start_probs : torch.FloatTensor
        The result of ``masked_softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
        probabilities of the span end position (exclusive).  Padding positions are
        replaced by ``-1e7``.
    span_end_probs : torch.FloatTensor
        The result of ``masked_softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of ``self._get_best_span``, a constrained inference over
        ``span_start_logits`` and ``span_end_logits`` to find the most probable span.
        Shape is ``(batch_size, 2)``.
    loss : torch.FloatTensor, optional
        Sum of the start and end negative log-likelihoods; only present when
        ``span_start`` is given.
    best_span_str : List[str]
        Only present when ``metadata`` is given AND ``self._official_eval_dataset`` is
        truthy: the string returned by ``self._compute_official_metrics`` for each
        instance.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    # The recurrent layers only receive masks when ``self._mask_lstms`` is enabled.
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    # The same phrase layer encodes both question and passage.
    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # The end predictor is conditioned on the start distribution through this summary.
    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                               passage_length,
                                                                               modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, span_end_encoding_dim), the output size of
    # ``self._span_end_encoder`` (presumably equal to encoding_dim — confirm).
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Push padded positions to -1e7 so they are never chosen by the best-span search.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self._get_best_span(span_start_logits, span_end_logits)

    output_dict = {"span_start_logits": span_start_logits,
                   "span_start_probs": span_start_probs,
                   "span_end_logits": span_end_logits,
                   "span_end_probs": span_end_probs,
                   "best_span": best_span}
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss
    if metadata is not None and self._official_eval_dataset:
        output_dict['best_span_str'] = []
        for i in range(batch_size):
            predicted_span = tuple(best_span[i].data.cpu().numpy())
            best_span_string = self._compute_official_metrics(metadata[i], predicted_span)  # type: ignore
            output_dict['best_span_str'].append(best_span_string)
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            for_training: bool = False) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    BiDAF forward pass preceded by a stochastic token-selection step on the passage.

    A keep-probability is predicted for every passage token, a Bernoulli sample decides
    which tokens are kept, and the tensors returned by ``helper.get_selected_tensor``
    replace the passage word and character ids before the usual BiDAF encoding,
    bidirectional attention, modeling and span prediction.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``; must contain ``tokens`` and ``token_characters`` entries.
        The model assumes that this passage contains the answer to the question, and
        predicts the beginning and ending positions of the answer within the passage.
        The caller's dictionary is not modified.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. The beginning position of the answer within the
        passage (an `inclusive` token index). If this is given, a loss is computed and
        included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. The ending position of the answer within the passage
        (an `inclusive` token index). Required whenever ``span_start`` is given.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance in the batch, providing ``question_tokens``,
        ``passage_tokens``, ``original_passage`` and ``token_offsets``, and optionally
        ``answer_texts`` (used to update the SQuAD metrics).
    for_training : ``bool``, optional (default = False)
        If ``True``, also computes the log-probability of the sampled selection, the
        number of selected non-padding tokens and the number of keep/drop switches.
        NOTE(review): these quantities are currently neither added to the loss nor
        returned.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Passage-to-question attention, ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span. Shape is
        ``(batch_size, 2)`` and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str, question_tokens, passage_tokens : List
        Present when ``metadata`` is given.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))

    # ---------------------------- token selection ----------------------------
    # Shape: (batch_size, passage_length); keep-probability of every passage token.
    pbx = sigmoid(self.linear(embedded_passage).squeeze(2))
    # Hard 0/1 keep decision per token, sampled from pbx.
    selection_x = pbx.bernoulli().long()
    # Word ids of kept tokens; dropped positions become 0.
    result_x = passage['tokens'].mul(selection_x)
    char_result_x = passage['token_characters'] * selection_x.unsqueeze(2).repeat(
            1, 1, passage['token_characters'].size()[2])
    selected_x, char_selected_x = helper.get_selected_tensor(
            result_x, char_result_x, pbx, passage['tokens'],
            passage['token_characters'], self.cuda_device)

    # NOTE(review): logpz / zsum / zdiff look like the ingredients of a selection
    # regulariser (REINFORCE-style), but nothing below consumes them yet.
    logpz = zsum = zdiff = -1.0
    if for_training:
        # 1 for real tokens, 0 for padding.
        mask1 = (passage['tokens'] != self._vocab.get_token_index(DEFAULT_PADDING_TOKEN)).long()
        masked_selection_x = selection_x.mul(mask1)
        # Element-wise log-likelihood of the sampled selection, (batch_size, passage_length).
        logpx = -helper.binary_cross_entropy(pbx, selection_x.float().detach(), reduce=False)
        assert logpx.size() == passage['tokens'].size()
        # Summed over non-padding tokens -> (batch_size,)
        logpx = logpx.mul(mask1.float()).sum(1)
        logpz = logpx
        # Number of keep/drop switches between adjacent tokens -> (batch_size,)
        zdiff1 = (masked_selection_x[:, 1:] - masked_selection_x[:, :-1]).abs().sum(1)
        assert zdiff1.size()[0] == passage['tokens'].size()[0]
        zdiff = zdiff1
        # Number of selected non-padding tokens -> (batch_size,)
        xsum = masked_selection_x.sum(1)
        zsum = xsum
        assert zsum.size()[0] == passage['tokens'].size()[0]
        assert logpz.dim() == zsum.dim()
        assert logpz.dim() == zdiff.dim()

    # Shallow copy so the caller's batch dictionary keeps the original passage.
    passage = dict(passage)
    passage['tokens'] = selected_x
    passage['token_characters'] = char_selected_x
    # NOTE(review): span_start/span_end and metadata['token_offsets'] index the original
    # passage; this is only consistent if get_selected_tensor preserves token positions.

    # ------------------------------- BiDAF -------------------------------
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].data.cpu().numpy())
            # Map token indices back to character offsets in the original passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis0: Dict[str, torch.LongTensor],
            hypothesis1: Dict[str, torch.LongTensor],
            hypothesis2: Dict[str, torch.LongTensor] = None,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Scores every hypothesis against the premise with an ESIM-style network and
    normalises the scores across hypotheses (multiple-choice style).

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``.
    hypothesis0, hypothesis1 : Dict[str, torch.LongTensor]
        From ``TextField``s; candidates that are always present.
    hypothesis2 : Dict[str, torch.LongTensor], optional (default = None)
        From a ``TextField``; an optional third candidate, skipped when ``None``.
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``; index of the correct hypothesis.

    Returns
    -------
    An output dictionary consisting of:
    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_hypotheses)``: one unnormalised score per
        hypothesis.
    label_probs : torch.FloatTensor
        ``softmax(label_logits)`` over the hypotheses.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised, present when ``label`` is given.
    """
    hyps = [h for h in [hypothesis0, hypothesis1, hypothesis2] if h is not None]
    uses_elmo = isinstance(self._text_field_embedder, ElmoTokenEmbedder)

    # The ELMo LSTM states are reset before embedding each text, presumably so that no
    # state carries over from one text to the next.
    if uses_elmo:
        self._text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states()
    embedded_premise = self._text_field_embedder(premise)
    embedded_hypotheses = []
    for hypothesis in hyps:
        if uses_elmo:
            self._text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states()
        embedded_hypotheses.append(self._text_field_embedder(hypothesis))

    premise_mask = get_text_field_mask(premise).float()
    hypothesis_masks = [get_text_field_mask(hypothesis).float() for hypothesis in hyps]

    # apply dropout for LSTM
    if self.rnn_input_dropout:
        embedded_premise = self.rnn_input_dropout(embedded_premise)
        embedded_hypotheses = [self.rnn_input_dropout(hyp) for hyp in embedded_hypotheses]

    # The premise is encoded once and shared by all hypotheses.
    encoded_premise = self._encoder(embedded_premise, premise_mask)

    label_logits = []
    for embedded_hypothesis, hypothesis_mask in zip(embedded_hypotheses, hypothesis_masks):
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat([encoded_premise,
                                      attended_hypothesis,
                                      encoded_premise - attended_hypothesis,
                                      encoded_premise * attended_hypothesis],
                                     dim=-1)
        hypothesis_enhanced = torch.cat([encoded_hypothesis,
                                         attended_premise,
                                         encoded_hypothesis - attended_premise,
                                         encoded_hypothesis * attended_premise],
                                        dim=-1)

        # The projection layer down to the model dimension (no dropout in projection).
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- masked max and masked average pooling over time.
        v_a_max, _ = replace_masked_values(v_ai, premise_mask.unsqueeze(-1), -1e7).max(dim=1)
        v_b_max, _ = replace_masked_values(v_bi, hypothesis_mask.unsqueeze(-1), -1e7).max(dim=1)
        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True)

        # Shape: (batch_size, 4 * inference_encoder_output_dim)
        v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v = self.dropout(v)
        output_hidden = self._output_feedforward(v)
        logit = self._output_logit(output_hidden)
        # One score per (instance, hypothesis); concatenated across hypotheses below.
        assert logit.size(-1) == 1
        label_logits.append(logit)

    # Shape: (batch_size, num_hypotheses)
    label_logits = torch.cat(label_logits, -1)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        # Same 1-D gold vector for loss and accuracy (squeeze(-1) would give a 0-dim
        # tensor for a batch of one with shape (1,)).
        gold_labels = label.long().view(-1)
        loss = self._loss(label_logits, gold_labels)
        self._accuracy(label_logits, gold_labels)
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predicts an answer span, an answer type and a gate for every passage sentence.

    Two BiDAF-style branches run over the same embeddings. The ``*_sp`` branch feeds
    ``self._span_gate``, which scores each sentence span. Gated span representations
    of the answer branch (gate value >= 0.3) are mapped back to token positions and
    added to the answer branch's passage encoding before self-attention and span/type
    prediction.

    Parameters
    ----------
    question, passage : Dict[str, torch.LongTensor]
        From ``TextField``s.
    span_start, span_end : torch.IntTensor, optional
        Gold inclusive answer token indices; when given (together with ``q_type``) the
        loss is computed.
    sentence_spans : torch.IntTensor
        Sentence spans in the passage, consumed by ``convert_sequence_to_spans``.
    sent_labels : torch.IntTensor
        Per-sentence gate labels of shape (batch_size, num_spans); negative entries are
        padding and ignored by the gate loss. Used unconditionally.
    evd_chain_labels : torch.IntTensor, optional
        Only read when ``self._sent_labels_src == 'chain'``: the first chain is turned into
        the gate labels, where chain index 0 denotes the end position and index ``k``
        the ``k-1``-th sentence.
    q_type : torch.IntTensor, optional
        Gold answer type; predicted type 1 yields "yes", 2 yields "no", otherwise a span.
    metadata : List[Dict[str, Any]], optional
        Per-instance dictionaries; the keys read are listed in the loop below.

    Returns
    -------
    Dict with ``span_start_logits``, ``span_end_logits``, ``best_span``,
    ``pred_sent_labels``, ``gate_probs``, optional attention scores, ``loss`` when gold
    spans are given, and metadata-derived lists when ``metadata`` is given.
    """
    if self._sent_labels_src == 'chain':
        batch_size, num_spans = sent_labels.size()
        sent_labels_mask = (sent_labels >= 0).float()
        # We use the chain as the label to supervise the gate.
        # Only the first chain in ``evd_chain_labels`` is used for supervision;
        # right now the number of chains should only be one too.
        evd_chain_labels = evd_chain_labels[:, 0].long()
        # Build the gate labels. The dim is 1 + num_spans to account for the end embedding.
        # shape: (batch_size, 1 + num_spans)
        sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
        sent_labels.scatter_(1, evd_chain_labels, 1.)
        # Remove the column for the end embedding.
        # shape: (batch_size, num_spans)
        sent_labels = sent_labels[:, 1:].float()
        # Make the padding be -1.
        sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

    # word + char embedding
    embedded_question = self._text_field_embedder(question)
    embedded_passage = self._text_field_embedder(passage)
    # mask
    ques_mask = util.get_text_field_mask(question).float()
    context_mask = util.get_text_field_mask(passage).float()

    # BiDAF for answer prediction
    ques_output = self._dropout(self._phrase_layer(embedded_question, ques_mask))
    context_output = self._dropout(self._phrase_layer(embedded_passage, context_mask))
    modeled_passage, _, qc_score = self.qc_att(context_output, ques_output, ques_mask)
    modeled_passage = self._modeling_layer(modeled_passage, context_mask)

    # BiDAF for gate prediction
    ques_output_sp = self._dropout(self._phrase_layer_sp(embedded_question, ques_mask))
    context_output_sp = self._dropout(self._phrase_layer_sp(embedded_passage, context_mask))
    modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(context_output_sp, ques_output_sp, ques_mask)
    modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp, context_mask)

    # gate prediction
    # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
    # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
    spans_rep_sp, spans_mask = convert_sequence_to_spans(modeled_passage_sp, sentence_spans)
    spans_rep, _ = convert_sequence_to_spans(modeled_passage, sentence_spans)
    # Shape(gate_logit): (batch_size * num_spans, 2)
    # Shape(gate): (batch_size * num_spans, 1)
    # Shape(pred_sent_probs): (batch_size * num_spans, 2)
    # Shape(gate_mask): (batch_size, num_spans)
    gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
            spans_rep_sp, spans_mask, self._gate_self_attention_layer, self._gate_sent_encoder)
    batch_size, num_spans, _ = spans_mask.size()

    # Gate supervision; padded sentences (label -1) are ignored.
    strong_sup_loss = F.nll_loss(
            F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
            sent_labels.long().view(batch_size * num_spans),
            ignore_index=-1)

    # Hard gate: keep spans whose gate value reaches 0.3.
    gate = (gate >= 0.3).long()
    spans_rep = spans_rep * gate.unsqueeze(-1).float()
    attended_sent_embeddings = convert_span_to_sequence(modeled_passage_sp, spans_rep, spans_mask)
    modeled_passage = attended_sent_embeddings + modeled_passage

    self_att_passage = self._self_attention_layer(modeled_passage, mask=context_mask)
    modeled_passage = modeled_passage + self_att_passage[0]
    self_att_score = self_att_passage[2]

    # Padding positions get a huge negative logit.
    output_start = self._span_start_encoder(modeled_passage, context_mask)
    span_start_logits = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask)
    output_end = torch.cat([modeled_passage, output_start], dim=2)
    output_end = self._span_end_encoder(output_end, context_mask)
    span_end_logits = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask)

    # Answer type from max-pooled token features.
    output_type = torch.cat([modeled_passage, output_end, output_start], dim=2)
    output_type = torch.max(output_type, 1)[0]
    predict_type = self.linear_type(output_type)
    type_predicts = torch.argmax(predict_type, 1)

    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span": best_span,
            "pred_sent_labels": gate.view(batch_size, num_spans),  # [B, num_span]
            "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans),  # [B, num_span]
    }
    if self._output_att_scores:
        if qc_score is not None:
            output_dict['qc_score'] = qc_score
        if qc_score_sp is not None:
            output_dict['qc_score_sp'] = qc_score_sp
        if self_att_score is not None:
            output_dict['self_attention_score'] = self_att_score
        if g_att_score is not None:
            output_dict['evd_self_attention_score'] = g_att_score

    # Compute the loss for training.
    if span_start is not None:
        # NOTE(review): a RuntimeError here only prints diagnostics and leaves "loss" out of
        # the output; consider re-raising once the underlying shape issue is understood.
        try:
            start_loss = nll_loss(util.masked_log_softmax(span_start_logits, None),
                                  span_start.squeeze(-1))
            end_loss = nll_loss(util.masked_log_softmax(span_end_logits, None),
                                span_end.squeeze(-1))
            type_loss = nll_loss(util.masked_log_softmax(predict_type, None), q_type)
            loss = start_loss + end_loss + type_loss + strong_sup_loss
            self._loss_trackers['loss'](loss)
            self._loss_trackers['start_loss'](start_loss)
            self._loss_trackers['end_loss'](end_loss)
            self._loss_trackers['type_loss'](type_loss)
            self._loss_trackers['strong_sup_loss'](strong_sup_loss)
            output_dict["loss"] = loss
        except RuntimeError:
            print('\n meta_data:', metadata)
            print(span_start_logits.shape)

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        token_spans_sp = []
        token_spans_sent = []
        sent_labels_list = []
        evd_possible_chains = []
        ans_sent_idxs = []
        ids = []
        for i in range(batch_size):
            meta = metadata[i]
            question_tokens.append(meta['question_tokens'])
            passage_tokens.append(meta['passage_tokens'])
            token_spans_sp.append(meta['token_spans_sp'])
            token_spans_sent.append(meta['token_spans_sent'])
            sent_labels_list.append(meta['sent_labels'])
            ids.append(meta['_id'])
            passage_str = meta['original_passage']
            offsets = meta['token_offsets']
            if type_predicts[i] == 1:
                best_span_string = 'yes'
            elif type_predicts[i] == 2:
                best_span_string = 'no'
            else:
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = meta.get('answer_texts', [])
            output_dict['answer_texts'].append(answer_texts)
            if answer_texts:
                self._squad_metrics(best_span_string.lower(), answer_texts)
            # Shift sentence indices back by one (index 0 is the end position, as in the
            # chain-based gate labels).
            evd_possible_chains.append([s_idx - 1 for s_idx in meta['evd_possible_chains'][0]
                                        if s_idx > 0])
            ans_sent_idxs.append([s_idx - 1 for s_idx in meta['ans_sent_idxs']])
        self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
        output_dict['token_spans_sp'] = token_spans_sp
        output_dict['token_spans_sent'] = token_spans_sent
        output_dict['sent_labels'] = sent_labels_list
        output_dict['evd_possible_chains'] = evd_possible_chains
        output_dict['ans_sent_idxs'] = ans_sent_idxs
        output_dict['_id'] = ids

    return output_dict
def forward(self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            passages_length: torch.LongTensor = None,
            correct_passage: torch.LongTensor = None,
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    """
    Jointly predicts an answer span over concatenated passages and ranks the passages.

    Parameters
    ----------
    question, passage : Dict[str, torch.LongTensor]
        From ``TextField``s; ``passage`` holds all candidate passages concatenated.
    passages_length : torch.LongTensor
        Shape (batch_size, num_passages): number of tokens of each passage inside
        ``passage``. Required despite the ``None`` default.
    correct_passage : torch.LongTensor, optional
        Index of the gold passage; required whenever ``span_start`` is given.
    span_start, span_end : torch.IntTensor, optional
        Gold inclusive token indices of the answer in the concatenated passage. When
        given, the loss ``r * L_AP + (1 - r) * L_PR`` is returned, where ``L_AP`` is the
        answer-span NLL, ``L_PR`` the passage-ranking NLL and ``r = self._r``.

    Returns
    -------
    Dict with ``span_start_idx`` / ``span_end_idx`` (argmax positions, shape
    (batch_size,)) and, when gold spans are given, ``loss``.
    """
    # shape: B x Tq x E
    embedded_question = self._embedder(question)
    embedded_passage = self._embedder(passage)
    batch_size = embedded_question.size(0)
    total_passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question)
    passage_mask = util.get_text_field_mask(passage)

    # shape: B x T x 2H
    encoded_question = self._dropout(self._question_encoder(embedded_question, question_mask))
    encoded_passage = self._dropout(self._passage_encoder(embedded_passage, passage_mask))
    passage_mask = passage_mask.float()
    question_mask = question_mask.float()

    encoding_dim = encoded_question.size(-1)

    # shape: B x 2H; allocated with the type/device of the encoder output so the CPU
    # path works too.
    gru_hidden = Variable(encoded_passage.data.new(batch_size, encoding_dim).zero_())

    question_awared_passage = []
    for timestep in range(total_passage_length):
        # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
        attn_weights = self._question_attention_for_passage(
                encoded_passage[:, timestep, :], encoded_question, question_mask)
        # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
        attended_question = util.weighted_sum(encoded_question, attn_weights)
        # shape: B x 4H
        passage_question_combined = torch.cat(
                [encoded_passage[:, timestep, :], attended_question], dim=-1)
        # shape: B x 4H
        gate = F.sigmoid(self._gate(passage_question_combined))
        gru_input = gate * passage_question_combined
        # shape: B x 2H
        gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
        question_awared_passage.append(gru_hidden)

    # shape: B x T x 2H
    # question aware passage representation v_P
    question_awared_passage = torch.stack(question_awared_passage, dim=1)

    # compute question vector r_Q
    # shape: B x T = attention(B x 2H, B x T x 2H)
    v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
    attn_weights = self._question_attention_for_question(v_r_Q_tiled, encoded_question,
                                                         question_mask)
    # shape: B x 2H
    r_Q = util.weighted_sum(encoded_question, attn_weights)

    # shape: B x T = attention(B x 2H, B x T x 2H)
    span_start_logits = self._passage_attention_for_answer(r_Q, question_awared_passage,
                                                           passage_mask)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_mask)

    # shape: B x 2H
    c_t = util.weighted_sum(question_awared_passage, span_start_probs)
    # shape: B x 2H
    h_1 = self._dropout(self._answer_net(c_t, r_Q))

    span_end_logits = self._passage_attention_for_answer(h_1, question_awared_passage,
                                                         passage_mask)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_mask)

    num_passages = passages_length.size(1)
    # Cumulative passage lengths as plain Python ints: passage i of instance b occupies
    # tokens [boundaries[i], boundaries[i + 1]) of the concatenated passage.
    passage_ends = torch.cumsum(passages_length, dim=1).data.cpu().tolist()

    g_batch = []
    for b in range(batch_size):
        boundaries = [0] + passage_ends[b]
        g = []
        for i in range(num_passages):
            start, end = boundaries[i], boundaries[i + 1]
            if end > start:
                attn_weights = self._passage_attention_for_ranking(
                        r_Q[b:b + 1],
                        question_awared_passage[b:b + 1, start:end, :],
                        passage_mask[b:b + 1, start:end])
                r_P = util.weighted_sum(question_awared_passage[b:b + 1, start:end, :],
                                        attn_weights)
                question_passage_combined = torch.cat([r_Q[b:b + 1], r_P], dim=-1)
                gi = self._dropout(
                        self._match_layer_2(
                                F.tanh(self._dropout(self._match_layer_1(question_passage_combined)))))
                g.append(gi)
            else:
                # Empty passage: score 0, on the same type/device as the other scores.
                g.append(Variable(r_Q.data.new(1, 1).zero_()))
        g = torch.cat(g, dim=1)
        g_batch.append(g)

    # shape: B x num_passages
    g = torch.cat(g_batch, dim=0)
    passage_log_probs = F.log_softmax(g, dim=-1)

    output_dict = {}
    if span_start is not None:
        AP_loss = F.nll_loss(span_start_log_probs, span_start.squeeze(-1)) + \
            F.nll_loss(span_end_log_probs, span_end.squeeze(-1))
        PR_loss = F.nll_loss(passage_log_probs, correct_passage.squeeze(-1))
        # Interpolate answer prediction and passage ranking with weight r.
        loss = self._r * AP_loss + (1 - self._r) * PR_loss
        output_dict['loss'] = loss

    _, max_start = torch.max(span_start_probs, dim=1)
    _, max_end = torch.max(span_end_probs, dim=1)
    output_dict['span_start_idx'] = max_start
    output_dict['span_end_idx'] = max_end

    # Periodic gold-vs-predicted logging; only possible when gold spans are given.
    global ITE
    if span_start is not None:
        ITE += 1
        if ITE % 100 == 0:
            print(" gold %i:%i|predicted %i:%i" % (span_start.squeeze(-1)[0],
                                                   span_end.squeeze(-1)[0],
                                                   max_start.data[0],
                                                   max_end.data[0]))

    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            choices: Dict[str, torch.LongTensor],
            evidence: Dict[str, torch.LongTensor],
            answer_index: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Scores each answer choice from the first-position BERT vectors of the question, the
    choice and (for ``'first'`` model types) the first evidence of that choice.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        Question tokens, (batch, seq_len).
    choices : Dict[str, torch.LongTensor]
        Choice tokens, (batch, num_choices, seq_len).
    evidence : Dict[str, torch.LongTensor]
        Evidence tokens, (batch, num_choices, evi_num, seq_len); only read when
        ``'first'`` is in ``self.model_type``.
    answer_index : ``torch.IntTensor``, optional
        From an ``IndexField``. This is what we are trying to predict. If this is given,
        we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance; its ``qid`` key is copied to the output.

    Returns
    -------
    An output dictionary consisting of the followings.
    answer_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_choices)`` representing unnormalised log
        probabilities of the choices.
    answer_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_choices)`` representing probabilities of
        the choices.
    qid : List[str], optional
        Question ids, present when ``metadata`` is given.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    use_evidence = 'first' in self.model_type

    # batch, seq_len -> batch, seq_len, emb
    question_hidden = self._bert(question)
    batch_size, emb_size = question_hidden.size(0), question_hidden.size(2)
    question_hidden = question_hidden[..., 0, :]  # batch, emb
    # batch, num_choices, seq_len -> batch, num_choices, seq_len, emb
    choice_hidden = self._bert(choices, num_wrapping_dims=1)
    choice_hidden = choice_hidden[..., 0, :]  # batch, num_choices, emb
    num_choices = choice_hidden.size(1)
    if use_evidence:
        # batch, num_choices, evi_num, seq_len -> batch, num_choices, evi_num, seq_len, emb
        evidence_hidden = self._bert(evidence, num_wrapping_dims=2)
        evidence_hidden = evidence_hidden[..., 0, :]  # batch, num_choices, evi_num, emb

    if self.dropout:
        question_hidden = self.dropout(question_hidden)
        choice_hidden = self.dropout(choice_hidden)
        if use_evidence:
            evidence_hidden = self.dropout(evidence_hidden)

    # batch, num_choices, emb
    question_hidden = question_hidden.unsqueeze(1).expand(batch_size, num_choices, emb_size)
    cls_hidden = torch.cat([question_hidden, choice_hidden], dim=-1)
    if use_evidence:
        # Only the first evidence of each choice is used: batch, num_choices, emb
        evidence_summary = evidence_hidden[..., 0, :]
        cls_hidden = torch.cat([cls_hidden, evidence_summary], dim=-1)

    # The classifier maps each choice's features to a single score.
    answer_logits = self._classifier(cls_hidden).squeeze(-1)
    answer_probs = torch.nn.functional.softmax(answer_logits, dim=-1)

    output_dict = {"answer_logits": answer_logits, "answer_probs": answer_probs}
    if metadata is not None:
        output_dict["qid"] = [m['qid'] for m in metadata]

    if answer_index is not None:
        answer_index = answer_index.squeeze(-1)
        loss = self._loss(answer_logits, answer_index)
        self._accuracy(answer_logits, answer_index)
        output_dict["loss"] = loss
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict an answer span in ``passage`` for ``question``.

    Pipeline: gated passage/question fusion, then passage self-attention, then a passage
    modeling layer. Start/end logits are computed as bilinear scores against a single
    encoded question vector.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. The `inclusive` token index where the answer begins. If this
        is given (together with ``span_end``), a loss is computed and included in the output.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. The `inclusive` token index where the answer ends.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance in the batch. Each must have the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage`` and
        ``token_offsets``. ``answer_texts`` is optional; when non-empty it is used to
        update the SQuAD EM/F1 metrics.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Passage-to-question attention of shape ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``; padding positions are set to ``-1e7``.
    span_start_probs : torch.FloatTensor
        Masked softmax of the start logits.
    span_end_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``; padding positions are set to ``-1e7``.
    span_end_probs : torch.FloatTensor
        Masked softmax of the end logits.
    best_span : torch.IntTensor
        Result of ``self.get_best_span`` over the start/end logits, shape ``(batch_size, 2)``.
    loss : torch.FloatTensor, optional
        Sum of the start and end negative log-likelihoods, plus half of an auxiliary
        per-token loss on ``self._span_predictor``'s output. The auxiliary target is 1 inside
        the gold span and 0 outside, weighted by ``self._span_weight``.
    best_span_str, question_tokens, passage_tokens : List
        Present only when ``metadata`` is given.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(
        passage_question_similarity, question_mask, dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # Shape: (batch_size, question_length, passage_length)
    question_passage_similarity = torch.transpose(passage_question_similarity, 1, 2)
    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask, dim=-1)
    # Shape: (batch_size, question_length, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)

    # Gated fusion: interpolate between the fused and the original encodings.
    passage_gate = torch.unsqueeze(
        self._passage_similarity_function(encoded_passage, passage_question_vectors), -1)
    passage_fusion = self._passage_fusion_function(encoded_passage, passage_question_vectors)
    gated_passage = passage_gate * passage_fusion + (1 - passage_gate) * encoded_passage

    question_gate = torch.unsqueeze(
        self._question_similarity_function(encoded_question, question_passage_vector), -1)
    question_fusion = self._question_fusion_function(encoded_question, question_passage_vector)
    gated_question = question_gate * question_fusion + (1 - question_gate) * encoded_question

    # Passage self-attention over the gated passage.
    passage_passage_similarity = self._self_matrix_attention(gated_passage, gated_passage)
    passage_passage_attention = util.masked_softmax(
        passage_passage_similarity, passage_mask, dim=-1)
    passage_passage_vector = util.weighted_sum(gated_passage, passage_passage_attention)
    final_passage = self._fusion_function(gated_passage, passage_passage_vector)

    modeled_passage = self._dropout(
        self._passage_modeling_layer(final_passage, passage_lstm_mask))
    # Auxiliary per-token logits, only consumed by the training loss below.
    span_logits = self._span_predictor(modeled_passage)

    modeled_question = self._question_modeling_layer(gated_question, question_lstm_mask)
    question_vector = self._question_encoding_layer(
        modeled_question, question_lstm_mask).unsqueeze(-1)

    # Bilinear scoring of every passage position against the question vector.
    span_start_logits = torch.bmm(self._span_start_weight(modeled_passage),
                                  question_vector).squeeze(-1)
    span_end_logits = torch.bmm(self._span_end_weight(modeled_passage),
                                question_vector).squeeze(-1)

    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        device_id = util.get_device_of(span_start)
        weight = self._span_weight.cuda(device_id) if device_id >= 0 else self._span_weight
        # Per-token target: 1 inside [span_start, span_end], 0 outside.
        arange_mask = util.get_range_vector(passage_length, device_id)
        span_mask = (arange_mask >= span_start) & (arange_mask <= span_end)
        span_loss = nll_loss(
            self._masked_log_softmax(span_logits, passage_mask).transpose(1, 2),
            span_mask.long(),
            weight=weight)

        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        # cat (not stack) so the gold spans have shape (batch_size, 2), like best_span.
        self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
        output_dict["loss"] = loss + span_loss / 2

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    BiDAF-style span prediction with extra self-attention encodings and a "no answer" (NA)
    slot.

    Several sub-computations are explicitly placed on ``cuda:0``, ``cuda:1`` and ``cuda:2``,
    and tensors are cast with ``torch.cuda`` types. This method therefore requires at least
    three CUDA devices.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. The `inclusive` start token index of the answer, or ``-1``
        when the question has no answer. If given (with ``span_end``), a loss is computed.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. The `inclusive` end token index of the answer, or ``-1``
        when the question has no answer.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance in the batch. Each must have the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage`` and
        ``token_offsets``. ``answer_texts`` is optional; when non-empty it is used to
        update the SQuAD EM/F1 metrics.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Shape ``(batch_size, passage_length, question_length)``.
    span_start_logits, span_end_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``, without the NA slot; padding set to ``-1e7``.
    span_start_probs, span_end_probs : torch.FloatTensor
        Masked softmax of the corresponding logits (NA slot excluded).
    best_span : torch.IntTensor
        Result of ``self.get_best_span`` over the start/end logits.
    na_logits : torch.FloatTensor
        Shape ``(batch_size,)``. Despite the name, this is a probability: the product of
        the NA-slot probabilities of the start and end distributions.
    na_probs : torch.FloatTensor
        Shape ``(batch_size, 2)``: ``[1 - na_logits, na_logits]``.
    loss : torch.FloatTensor, optional
        Sum of the start and end NLL over the NA-augmented distributions.
    best_span_str, question_tokens, passage_tokens : List
        Present only when ``metadata`` is given.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # NOTE: unsqueeze_ is in-place, so question_mask itself becomes
    # (batch_size, 1, question_length) from here on. The replace_masked_values calls
    # below rely on that shape to broadcast against the passage-question similarity.
    question_mask_temp = question_mask
    question_mask_temp = question_mask_temp.unsqueeze_(1)

    # Self-attention encodings, spread across GPUs.
    device0 = torch.device('cuda:0')
    device1 = torch.device('cuda:1')
    device2 = torch.device('cuda:2')

    sa_question_sim = self._dropout(
        self._sa_matrix_attention(device1, embedded_question, None, True))
    sa_question_att = util.masked_softmax(sa_question_sim.to(device1),
                                          question_mask_temp.to(device1))
    sa_encoded_question = util.weighted_sum(encoded_question.to(device1),
                                            sa_question_att.to(device1))
    sa_encoded_question = sa_encoded_question.to(device0)

    sa_passage_sim = self._dropout(
        self._sa_matrix_attention(device2, embedded_passage, None, True))
    sa_passage_att = util.masked_softmax(
        sa_passage_sim.to(device1),
        passage_mask.clone().unsqueeze_(1).to(device1))
    sa_encoded_passage = util.weighted_sum(encoded_passage.to(device1),
                                           sa_passage_att.to(device1))
    sa_encoded_passage = sa_encoded_passage.to(device0)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    sa_passage_question_similarity = self._l_matrix_attention(
        device1, sa_encoded_passage, sa_encoded_question, False)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity,
                                                     question_mask_temp)
    sa_passage_question_attention = util.masked_softmax(
        sa_passage_question_similarity.to(device0), question_mask_temp)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
    sa_passage_question_vectors = util.weighted_sum(sa_encoded_question.to(device0),
                                                    sa_passage_question_attention)

    # Replace masked values with something really negative so they don't affect the max.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask, -1e7)
    sa_masked_similarity = util.replace_masked_values(
        sa_passage_question_similarity.to(device0), question_mask, -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    sa_question_passage_similarity = sa_masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    sa_question_passage_attention = util.masked_softmax(sa_question_passage_similarity,
                                                        passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    sa_question_passage_vector = util.weighted_sum(sa_encoded_passage,
                                                   sa_question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
        batch_size, passage_length, encoding_dim)
    sa_tiled_question_passage_vector = sa_question_passage_vector.unsqueeze(1).expand(
        batch_size, passage_length, encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + 4 * sa_dim)
    final_merged_passage = torch.cat([
        encoded_passage,
        sa_encoded_passage,
        passage_question_vectors,
        sa_passage_question_vectors,
        encoded_passage * passage_question_vectors,
        encoded_passage * tiled_question_passage_vector,
        sa_encoded_passage * sa_passage_question_vectors,
        sa_encoded_passage * sa_tiled_question_passage_vector
    ], dim=-1)

    modeled_passage = self._dropout(
        self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # The *_w_na predictors consume a fixed-width input: logits are right-padded to this
    # many positions. Their output is sliced to passage_length + 1 (slot 0 = NA).
    max_passage_length = 1000

    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    span_start_logits_pad = (0, max_passage_length - passage_length, 0, 0)
    span_start_logits_w_na = self._span_start_predictor_w_na(
        pad(span_start_logits, span_start_logits_pad)).squeeze(-1)
    # Shape: (batch_size, passage_length + 1)
    span_start_logits_w_na = span_start_logits_w_na[:, :passage_length + 1]
    span_start_logits = span_start_logits_w_na[:, 1:]
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(
        batch_size, passage_length, modeling_dim)

    span_end_representation = torch.cat([
        final_merged_passage,
        modeled_passage,
        tiled_start_representation,
        modeled_passage * tiled_start_representation
    ], dim=-1)
    encoded_span_end = self._dropout(
        self._span_end_encoder(span_end_representation, passage_lstm_mask))
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_logits_pad = (0, max_passage_length - passage_length, 0, 0)
    span_end_logits_w_na = self._span_end_predictor_w_na(
        pad(span_end_logits, span_end_logits_pad)).squeeze(-1)
    # Shape: (batch_size, passage_length + 1)
    span_end_logits_w_na = span_end_logits_w_na[:, :passage_length + 1]
    span_end_logits = span_end_logits_w_na[:, 1:]
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Mask over [NA slot] + passage; the NA slot is always valid. It does not depend on
    # labels, so it is built here and the NA outputs are available at inference too.
    passage_mask_w_na = torch.cat([
        torch.ones([batch_size, 1]).type(torch.cuda.FloatTensor), passage_mask
    ], -1)

    # Compute the loss for training.
    if span_start is not None:
        # Gold index -1 marks an unanswerable question. These must only be computed when
        # labels exist: comparing None to -1 yields a plain bool with no .type().
        na_gt = (span_start == -1).type(torch.cuda.LongTensor)
        na_inv = (1.0 - na_gt)
        # Shift gold indices by one so that index 0 addresses the NA slot.
        y_start = span_start + 1
        y_end = span_end + 1
        loss = 0.0

        # loss for start
        preds_start = util.masked_log_softmax(
            span_start_logits_w_na.type(torch.cuda.FloatTensor),
            passage_mask_w_na.type(torch.cuda.FloatTensor)).type(torch.cuda.FloatTensor)
        y_start = y_start.squeeze(-1).type(torch.cuda.LongTensor)
        loss += nll_loss(preds_start, y_start)
        # accuracy for start (unanswerable rows are zeroed on both sides)
        acc_p_start = na_inv.type(torch.cuda.FloatTensor) * span_start_logits.type(
            torch.cuda.FloatTensor)
        acc_y_start = na_inv.squeeze(-1).type(torch.cuda.FloatTensor) * span_start.squeeze(
            -1).type(torch.cuda.FloatTensor)
        self._span_start_accuracy(acc_p_start, acc_y_start)

        # loss for end
        preds_end = util.masked_log_softmax(
            span_end_logits_w_na.type(torch.cuda.FloatTensor),
            passage_mask_w_na.type(torch.cuda.FloatTensor)).type(torch.cuda.FloatTensor)
        y_end = y_end.squeeze(-1).type(torch.cuda.LongTensor)
        loss += nll_loss(preds_end, y_end)
        # accuracy for end
        acc_p_end = na_inv.type(torch.cuda.FloatTensor) * span_end_logits.type(
            torch.cuda.FloatTensor)
        acc_y_end = na_inv.squeeze(-1).type(torch.cuda.FloatTensor) * span_end.squeeze(
            -1).type(torch.cuda.FloatTensor)
        self._span_end_accuracy(acc_p_end, acc_y_end)

        # accuracy for span
        acc_p = na_inv.type(torch.cuda.FloatTensor) * best_span.type(torch.cuda.FloatTensor)
        acc_y = na_inv.type(torch.cuda.FloatTensor) * torch.cat([
            span_start.type(torch.cuda.FloatTensor),
            span_end.type(torch.cuda.FloatTensor)
        ], -1)
        self._span_accuracy(acc_p, acc_y)
        output_dict["loss"] = loss

    # Probability of "no answer" = P(start = NA) * P(end = NA).
    preds_start = util.masked_softmax(
        span_start_logits_w_na.type(torch.cuda.FloatTensor),
        passage_mask_w_na.type(torch.cuda.FloatTensor)).type(torch.cuda.FloatTensor)
    preds_end = util.masked_softmax(
        span_end_logits_w_na.type(torch.cuda.FloatTensor),
        passage_mask_w_na.type(torch.cuda.FloatTensor)).type(torch.cuda.FloatTensor)
    output_dict["na_logits"] = preds_start[:, 0] * preds_end[:, 0]
    output_dict["na_probs"] = torch.stack(
        [1.0 - output_dict["na_logits"], output_dict["na_logits"]], -1)

    if span_start is not None:
        # accuracy for answer existence
        self._na_accuracy(output_dict["na_probs"].type(torch.cuda.FloatTensor),
                          na_gt.squeeze(-1).type(torch.cuda.FloatTensor))

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,
        tokens_list: Dict[str, torch.LongTensor],
        positions_list: Dict[str, torch.LongTensor],
        sent_positions_list: Dict[str, torch.LongTensor],
        before_loc_start: torch.IntTensor = None,
        before_loc_end: torch.IntTensor = None,
        after_loc_start_list: torch.IntTensor = None,
        after_loc_end_list: torch.IntTensor = None,
        before_category: torch.IntTensor = None,
        after_category_list: torch.IntTensor = None,
        before_category_mask: torch.IntTensor = None,
        after_category_mask_list: torch.IntTensor = None
) -> Dict[str, torch.Tensor]:
    """
    Track an entity's location through a sequence of paragraph states ("steps").

    At step 0 the model predicts the initial (before) location: a three-way category plus a
    start/end span. At every step it predicts the after location, conditioned on the
    previous step's category logits and start distribution.

    Although every label argument defaults to ``None``, the code uses all of them
    unconditionally, so they are effectively required.

    :param tokens_list: Dict[str, torch.LongTensor]
        Output of ``TextField.as_array()`` for each step, passed directly to
        ``self._text_field_embedder``. A text-field mask is derived from it with one
        wrapping dimension.
    :param positions_list: same structure as ``tokens_list``, embedded by ``self._pos_field_embedder``
    :param sent_positions_list: same structure as ``tokens_list``, embedded by
        ``self._sent_pos_field_embedder``
    :param before_loc_start: ``IndexField`` tensor, the before-location start index
    :param before_loc_end: ``IndexField`` tensor, the before-location end index
    :param after_loc_start_list: ``ListField(IndexField)`` tensor, after-location start per step
    :param after_loc_end_list: ``ListField(IndexField)`` tensor, after-location end per step
    :param before_category: ``IndexField`` tensor, the before-location category
    :param after_category_list: ``ListField(IndexField)`` tensor, after-location category per step
    :param before_category_mask: ``IndexField`` tensor, 0/1 whether the before location is known
    :param after_category_mask_list: ``ListField(IndexField)`` tensor, 0/1 per step whether the
        after location is known
    :return: An output dictionary consisting of:
        best_span: torch.Tensor
            Merged predicted spans from ``self._get_merged_spans``, flattened to shape ``(1, N)``.
        true_span: torch.Tensor
            Merged gold spans, flattened the same way.
        loss: torch.FloatTensor
            Sum of the category and location negative log-likelihood terms over all steps.
    """
    # Each embedder: batchsize * listLength * paragraphSize * embeddingSize
    input_embedding_paragraph = self._text_field_embedder(tokens_list)
    input_pos_embedding_paragraph = self._pos_field_embedder(positions_list)
    input_sent_pos_embedding_paragraph = self._sent_pos_field_embedder(sent_positions_list)
    embedding_paragraph = torch.cat([
        input_embedding_paragraph,
        input_pos_embedding_paragraph,
        input_sent_pos_embedding_paragraph
    ], dim=-1)

    # batchsize * listLength * paragraphSize; shared by the text and sequence-label fields.
    para_mask = util.get_text_field_mask(tokens_list, num_wrapping_dims=1).float()
    # batchsize * listLength; 1 for real steps, 0 for padded list entries.
    para_index_mask = para_mask.max(dim=2)[0]

    # Zero the labels of padded steps.
    step_mask = para_index_mask.unsqueeze(2)
    after_loc_start_list = (after_loc_start_list.float() * step_mask).long()
    after_loc_end_list = (after_loc_end_list.float() * step_mask).long()
    after_category_list = (after_category_list.float() * step_mask).long()
    after_category_mask_list = (after_category_mask_list.float() * step_mask).long()

    batch_size, list_size, paragraph_size, _ = embedding_paragraph.size()

    # Carried from one step to the next; both are set in the index == 0 branch before use.
    tmp_category_probability = None
    tmp_start_probability = None
    loss = 0

    # Predicted after-location span per step; every entry is overwritten in the loop.
    best_span_after_list = torch.zeros(batch_size, list_size, 2)

    for index in range(list_size):
        # Integer indexing already drops the step dimension; no squeeze needed.
        embedding_paragraph_slice = embedding_paragraph[:, index, :, :]
        para_mask_slice = para_mask[:, index, :]
        para_lstm_mask_slice = para_mask_slice if self._mask_lstms else None
        para_index_mask_slice = para_index_mask[:, index]
        # squeeze(-1) rather than squeeze(): keeps the batch dim when batch_size == 1.
        after_category_mask_slice = after_category_mask_list[:, index, :].squeeze(-1)

        # Contextual embeddings for the current step: batchsize * paragraph_size * hidden
        encoded_paragraph = self._dropout(
            self._modeling_layer(embedding_paragraph_slice, para_lstm_mask_slice))
        # Max-pooling over paragraph positions feeds the three-category classifiers.
        category_input = encoded_paragraph.max(dim=1)[0]
        modeling_dim = encoded_paragraph.size(-1)
        span_start_input = encoded_paragraph

        if index == 0:
            # ---- initial (before) location ----
            category_predict_logits_before = self._category_before_predictor(category_input)
            tmp_category_probability = category_predict_logits_before

            # shape: batchsize * paragraph_size
            span_start_logits_before = self._span_start_predictor_before(
                span_start_input).squeeze(-1)
            span_start_probs_before = util.masked_softmax(span_start_logits_before,
                                                          para_mask_slice)
            tmp_start_probability = span_start_probs_before
            # shape: batchsize * hidden
            span_start_representation_before = util.weighted_sum(encoded_paragraph,
                                                                 span_start_probs_before)
            tiled_start_representation_before = span_start_representation_before.unsqueeze(
                1).expand(batch_size, paragraph_size, modeling_dim)
            # shape: batchsize * paragraph_size * 2hidden
            span_end_representation_before = torch.cat(
                [encoded_paragraph, tiled_start_representation_before], dim=-1)
            encoded_span_end_before = self._dropout(
                self._span_end_encoder_before(span_end_representation_before,
                                              para_lstm_mask_slice))
            encoded_span_end_before = torch.cat([encoded_paragraph, encoded_span_end_before],
                                                dim=-1)
            span_end_logits_before = self._span_end_predictor_before(
                encoded_span_end_before).squeeze(-1)

            best_span_before, _, _, _ = self._get_best_span_single_extend(
                span_start_logits_before, span_end_logits_before,
                category_predict_logits_before, before_category_mask)

            # Three-category loss. nll_loss expects log-probabilities, not probabilities.
            before_category_log_probs = torch.nn.functional.log_softmax(
                category_predict_logits_before, dim=-1)
            loss += nll_loss(before_category_log_probs, before_category.squeeze(-1))

            before_category_mask = before_category_mask.float()
            # Flatten so it multiplies the (batch_size,) log-likelihoods elementwise, as the
            # after-location path does, rather than broadcasting to (batch, batch).
            before_category_weight = before_category_mask.view(-1)

            # Start/end location losses, only for instances whose location is known.
            before_loc_start_pred = util.masked_log_softmax(span_start_logits_before,
                                                            para_mask_slice)
            logpy_before_start = torch.gather(before_loc_start_pred, 1,
                                              before_loc_start).view(-1).float()
            loss += -(logpy_before_start * before_category_weight).mean()
            before_loc_end_pred = util.masked_log_softmax(span_end_logits_before,
                                                          para_mask_slice)
            logpy_before_end = torch.gather(before_loc_end_pred, 1, before_loc_end).view(-1)
            loss += -(logpy_before_end * before_category_weight).mean()

            # Gold spans for evaluation; category outputs are turned into special spans.
            before_loc_start_real = self._get_real_spans_extend(
                before_loc_start, before_category, before_category_mask)
            before_loc_end_real = self._get_real_spans_extend(
                before_loc_end, before_category, before_category_mask)
            true_span_before = torch.stack([before_loc_start_real, before_loc_end_real],
                                           dim=-1).squeeze(1)

        # ---- after location for this step ----
        category_input_after = torch.cat((category_input, tmp_category_probability), dim=1)
        category_predict_logits_after = self._category_after_predictor(category_input_after)
        tmp_category_probability = category_predict_logits_after

        # NOTE(review): category_input is (batchsize, hidden) and does not vary over
        # paragraph positions. Weighting it by the previous start distribution therefore
        # appears to reduce to category_input scaled by that distribution's mass.
        # encoded_paragraph may have been intended; confirm before changing.
        prev_start = util.weighted_sum(category_input, tmp_start_probability)
        tiled_prev_start = prev_start.unsqueeze(1).expand(batch_size, paragraph_size,
                                                          modeling_dim)
        span_start_input_after = torch.cat((span_start_input, tiled_prev_start), dim=2)
        encoded_start_input_after = self._dropout(
            self._span_start_encoder_after(span_start_input_after, para_lstm_mask_slice))
        span_start_input_after_cat = torch.cat([encoded_paragraph, encoded_start_input_after],
                                               dim=-1)
        # shape: batchsize * paragraph_size
        span_start_logits_after = self._span_start_predictor_after(
            span_start_input_after_cat).squeeze(-1)
        span_start_probs_after = util.masked_softmax(span_start_logits_after, para_mask_slice)
        tmp_start_probability = span_start_probs_after

        span_start_representation_after = util.weighted_sum(encoded_paragraph,
                                                            span_start_probs_after)
        tiled_start_representation_after = span_start_representation_after.unsqueeze(
            1).expand(batch_size, paragraph_size, modeling_dim)
        span_end_representation_after = torch.cat(
            [encoded_paragraph, tiled_start_representation_after], dim=-1)
        encoded_span_end_after = self._dropout(
            self._span_end_encoder_after(span_end_representation_after, para_lstm_mask_slice))
        encoded_span_end_after = torch.cat([encoded_paragraph, encoded_span_end_after], dim=-1)
        span_end_logits_after = self._span_end_predictor_after(
            encoded_span_end_after).squeeze(-1)

        best_span_after, _, _, _ = self._get_best_span_single_extend(
            span_start_logits_after, span_end_logits_after,
            category_predict_logits_after, after_category_mask_slice)
        # Store with the slot's own (batch_size, 2) shape instead of relying on the
        # deprecated same-numel copy from (batch_size, 1, 2).
        best_span_after_list[:, index, :] = best_span_after.detach().view(batch_size, 2)

        # ---- loss for this step ----
        after_category_mask = after_category_mask_slice.float()  # batchsize
        after_category_slice = after_category_list[:, index, :]  # batchsize * 1
        after_loc_start_slice = after_loc_start_list[:, index, :]
        after_loc_end_slice = after_loc_end_list[:, index, :]

        # Three-category loss, ignoring padded steps.
        para_index_mask_slice_tiled = para_index_mask_slice.unsqueeze(1).expand(
            para_index_mask_slice.size(0), 3)
        after_category_pred = util.masked_log_softmax(category_predict_logits_after,
                                                      para_index_mask_slice_tiled)
        logpy_after_category = torch.gather(after_category_pred, 1,
                                            after_category_slice).view(-1)
        loss += -(logpy_after_category * para_index_mask_slice).mean()

        # Start/end location losses, only where the after location is known.
        after_loc_start_pred = util.masked_log_softmax(span_start_logits_after,
                                                       para_mask_slice)
        logpy_after_start = torch.gather(after_loc_start_pred, 1,
                                         after_loc_start_slice).view(-1)
        loss += -(logpy_after_start * after_category_mask).mean()
        after_loc_end_pred = util.masked_log_softmax(span_end_logits_after, para_mask_slice)
        logpy_after_end = torch.gather(after_loc_end_pred, 1, after_loc_end_slice).view(-1)
        loss += -(logpy_after_end * after_category_mask).mean()

    # Gold after-location spans for all steps (evaluation only).
    after_loc_start_real = self._get_real_spans_extend_list(
        after_loc_start_list, after_category_list, after_category_mask_list)
    after_loc_end_real = self._get_real_spans_extend_list(
        after_loc_end_list, after_category_list, after_category_mask_list)
    true_span_after = torch.stack([after_loc_start_real, after_loc_end_real], dim=-1)
    true_span_after = true_span_after.squeeze(2)
    true_span_after = true_span_after.view(
        true_span_after.size(0) * true_span_after.size(1),
        true_span_after.size(2)).float()
    best_span_after_list = Variable(best_span_after_list)

    # (batch_size * list_size, 2) mask that filters out padded steps when merging.
    para_index_mask_tiled = para_index_mask.view(-1, 1)
    para_index_mask_tiled = para_index_mask_tiled.expand(para_index_mask_tiled.size(0), 2)

    merged_sys_span, merged_gold_span = self._get_merged_spans(
        true_span_before, best_span_before, true_span_after, best_span_after_list,
        para_index_mask_tiled)

    output_dict = {}
    output_dict["best_span"] = merged_sys_span.view(
        1, merged_sys_span.size(0) * merged_sys_span.size(1))
    output_dict["true_span"] = merged_gold_span.view(
        1, merged_gold_span.size(0) * merged_gold_span.size(1))
    output_dict["loss"] = loss
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    BERT-style extractive QA over questions that may be split into several chunks.

    A linear layer (``self.qa_outputs``) on top of the embedded passage produces span
    start and end logits. Instances sharing a ``question_id`` are grouped, and the single
    highest-scoring span across the group is reported as that question's answer.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        Not read by this method (presumably the question is already packed into
        ``passage`` -- TODO confirm).
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``; its ``'bert'`` entry must be 2-D,
        ``(batch_size, num_passage_tokens)``.
    span_start, span_end : torch.IntTensor, optional
        Gold inclusive token indices. Instances whose ``span_start`` is negative have no
        gold answer in this chunk and are excluded from the loss.
    metadata : List[Dict[str, Any]]
        Required. One dict per instance with ``question_id``, ``original_passage``,
        ``token_offsets`` and ``answer_texts_list``. Instances of the same question are
        assumed to be contiguous in the batch.

    Returns
    -------
    A dict with ``span_start_logits``, ``span_end_logits`` and ``loss``. It also holds
    ``loss_start`` and ``loss_end`` when at least one instance has a gold span, and the
    per-instance lists ``best_span_str``, ``qid``, ``start_bias_weight`` and
    ``end_bias_weight``. The question-level answer is repeated for every chunk.
    """
    batch_size, _ = passage['bert'].size()

    # BERT for QA: a fully connected layer on top of BERT yields start and end logits.
    embedded_passage = self._text_field_embedder(passage)
    passage_length = embedded_passage.size(1)
    logits = self.qa_outputs(embedded_passage)
    start_logits, end_logits = logits.split(1, dim=-1)
    span_start_logits = start_logits.squeeze(-1)
    span_end_logits = end_logits.squeeze(-1)

    # Replace padding positions with a large negative, but finite, logit.
    passage_mask = util.get_text_field_mask(passage).float()
    repeated_passage_mask = passage_mask.reshape(batch_size, passage_length)
    span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

    # Span logits are exposed for knowledge distillation.
    output_dict: Dict[str, Any] = {
        "span_start_logits": span_start_logits,
        "span_end_logits": span_end_logits,
    }

    # Group instances per question, splitting wherever the question id changes.
    # This keeps batch order; group sizes from np.unique are ordered by sorted id and
    # would misalign the groups in a batch that is not sorted by question id.
    instances_question_id = [instance_meta['question_id'] for instance_meta in metadata]
    question_instances_split_inds = [
        ind for ind in range(1, len(instances_question_id))
        if instances_question_id[ind] != instances_question_id[ind - 1]
    ]
    per_question_inds = np.split(np.arange(batch_size), question_instances_split_inds)
    per_question_metadata = np.split(metadata, question_instances_split_inds)

    # Compute the loss only on instances that contain the gold answer (span_start >= 0);
    # during evaluation some chunks do not contain it.
    if span_start is not None:
        # 1-D index tensor; this stays 1-D even when exactly one instance qualifies.
        gold_inds = (span_start.view(-1) >= 0).nonzero().view(-1)
        if gold_inds.numel() > 0:
            loss_start = nll_loss(
                util.masked_log_softmax(span_start_logits[gold_inds],
                                        repeated_passage_mask[gold_inds]),
                span_start.view(-1)[gold_inds], ignore_index=-1)
            loss_end = nll_loss(
                util.masked_log_softmax(span_end_logits[gold_inds],
                                        repeated_passage_mask[gold_inds]),
                span_end.view(-1)[gold_inds], ignore_index=-1)
            # Separate tensors; an in-place `+=` would also mutate the stored loss_start.
            output_dict["loss_start"] = loss_start
            output_dict["loss_end"] = loss_end
            output_dict["loss"] = loss_start + loss_end

    # No gold answer in this batch: report a zero loss on the logits' device.
    if 'loss' not in output_dict:
        output_dict["loss"] = torch.zeros(1, device=span_end_logits.device)

    output_dict['best_span_str'] = []
    output_dict['qid'] = []
    output_dict["start_bias_weight"] = []
    output_dict["end_bias_weight"] = []

    # Best (start, end) token span for every instance.
    best_span = self._get_example_predications(span_start_logits, span_end_logits,
                                               self._max_span_length)
    best_span_cpu = best_span.detach().cpu().numpy()
    span_start_logits_numpy = span_start_logits.detach().cpu().numpy()
    span_end_logits_numpy = span_end_logits.detach().cpu().numpy()

    # One iteration per question (which may span several instances, one per chunk).
    for question_inds, question_instances_metadata in zip(per_question_inds,
                                                          per_question_metadata):
        # Score of each chunk's best span: start logit + end logit.
        question_span_scores = (
            span_start_logits_numpy[question_inds, best_span_cpu[question_inds][:, 0]]
            + span_end_logits_numpy[question_inds, best_span_cpu[question_inds][:, 1]])
        best_span_ind = int(np.argmax(question_span_scores))  # index within the question
        best_span_logit = question_span_scores[best_span_ind]
        best_instance_ind = int(question_inds[best_span_ind])  # index within the batch
        best_instance_metadata = question_instances_metadata[best_span_ind]

        passage_str = best_instance_metadata['original_passage']
        offsets = best_instance_metadata['token_offsets']
        predicted_span = best_span_cpu[best_instance_ind]
        start_offset = offsets[predicted_span[0]][0]
        end_offset = offsets[predicted_span[1]][1]
        best_span_string = passage_str[start_offset:end_offset]

        # Softmax probability of the predicted start/end token, used as a bias weight.
        # The batch tensors must be indexed with the batch-level index.
        start_bias_weight = util.masked_softmax(
            span_start_logits[best_instance_ind],
            repeated_passage_mask[best_instance_ind])[int(predicted_span[0])]
        end_bias_weight = util.masked_softmax(
            span_end_logits[best_instance_ind],
            repeated_passage_mask[best_instance_ind])[int(predicted_span[1])]

        # HACK: AllenNLP expects one output per instance, but a question may have several
        # chunks, so the question-level answer is repeated for each of its chunks.
        for _ in range(len(question_inds)):
            output_dict['best_span_str'].append(best_span_string)
            output_dict['qid'].append(best_instance_metadata['question_id'])
            output_dict["start_bias_weight"].append(start_bias_weight)
            output_dict["end_bias_weight"].append(end_bias_weight)

        f1_score = 0.0
        EM_score = 0.0
        gold_answer_texts = best_instance_metadata['answer_texts_list']
        if gold_answer_texts:
            f1_score = squad_eval.metric_max_over_ground_truths(
                squad_eval.f1_score, best_span_string, gold_answer_texts)
            EM_score = squad_eval.metric_max_over_ground_truths(
                squad_eval.exact_match_score, best_span_string, gold_answer_texts)
        self._official_f1(100 * f1_score)
        self._official_EM(100 * EM_score)

        # TODO move to predict
        if self._predictions_file is not None:
            with open(self._predictions_file, 'a') as f:
                f.write(json.dumps({'question_id': best_instance_metadata['question_id'],
                                    'best_span_logit': float(best_span_logit),
                                    'f1': 100 * f1_score, 'EM': 100 * EM_score,
                                    'best_span_string': best_span_string,
                                    'gold_answer_texts': gold_answer_texts,
                                    'qas_used_fraction': 1.0}) + '\n')

    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    BiDAF-style span prediction with one change: the question skips the phrase layer and
    its highway-projected token embeddings are used directly.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. The gold `inclusive` start token of the answer. If given
        (together with ``span_end``), a loss is computed and the span accuracies are
        updated.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. The gold `inclusive` end token of the answer.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, one dict per instance in the batch. ``question_tokens``,
        ``passage_tokens``, ``original_passage`` and ``token_offsets`` are read for every
        instance. ``answer_texts`` is optional; when it is non-empty it is used to update
        the SQuAD EM/F1 metrics.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Passage-to-question attention, ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        ``(batch_size, passage_length)`` unnormalized span start scores; padding positions
        are set to -1e7.
    span_start_probs : torch.FloatTensor
        Masked softmax of the span start scores.
    span_end_logits : torch.FloatTensor
        ``(batch_size, passage_length)`` unnormalized span end scores (inclusive); padding
        positions are set to -1e7.
    span_end_probs : torch.FloatTensor
        Masked softmax of the span end scores.
    best_span : torch.IntTensor
        Result of ``self.get_best_span`` over the start and end logits.
    loss : torch.FloatTensor, optional
        Sum of the start and end negative log-likelihoods, when gold spans are given.
    best_span_str, question_tokens, passage_tokens : List, optional
        Only when ``metadata`` is given: the predicted answer string and the tokens copied
        from ``metadata`` for every instance.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    # The question is not passed through the phrase layer. Because of
    # ``encoded_passage * passage_question_vectors`` below, the question embedding size
    # (set via the token embedders in the config) must match the phrase layer output size.
    encoded_question = self._dropout(embedded_question)
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length) -- similarity matrix
    similarity_matrix = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length) -- context-to-query attention
    passage_question_attention = util.last_dim_softmax(similarity_matrix, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # Query-to-context attention. Masked question positions get a very negative score so
    # they never win the max below.
    masked_similarity = util.replace_masked_values(similarity_matrix,
                                                   question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
        batch_size, passage_length, encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4) -- original beta combination
    final_merged_passage = torch.cat([
        encoded_passage,
        passage_question_vectors,
        encoded_passage * passage_question_vectors,
        encoded_passage * tiled_question_passage_vector
    ], dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(
        batch_size, passage_length, modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([
        final_merged_passage,
        modeled_passage,
        tiled_start_representation,
        modeled_passage * tiled_start_representation
    ], dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    # Mask padding only after the probabilities are computed, for span selection.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        # A single device-to-host transfer instead of one per instance.
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            instance_metadata = metadata[i]
            question_tokens.append(instance_metadata['question_tokens'])
            passage_tokens.append(instance_metadata['passage_tokens'])
            passage_str = instance_metadata['original_passage']
            offsets = instance_metadata['token_offsets']
            predicted_span = tuple(best_span_cpu[i])
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = instance_metadata.get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            get_sample_level_information=True) -> Dict[str, torch.Tensor]:
    """
    Run every submodel of the ensemble and merge their span predictions.

    Each submodel is moved to its own device (``submodel.cf_a.device``) only for its
    forward pass, then moved back to the CPU, so submodels are loaded one at a time.

    Parameters
    ----------
    question, passage, span_start, span_end, metadata :
        Forwarded unchanged to every submodel. ``metadata`` (one dict per instance, with
        ``original_passage``, ``token_offsets`` and optionally ``answer_texts``) is also
        used here to decode the merged span into a string and update the SQuAD metrics.
    get_sample_level_information : bool
        Forwarded to the submodels. When true, the output also contains:
        - per-instance EM/F1 (``em_samples`` / ``f1_samples``), only for instances that
          have answer texts;
        - when gold spans are given, the per-instance negative log-likelihood of the gold
          start/end under the submodels' averaged probabilities
          (``span_start_sample_loss`` / ``span_end_sample_loss``).

    Returns
    -------
    A dict with ``best_span`` (from ``merge_span_probs``), ``best_span_str``,
    ``models_output`` (the raw submodel outputs), plus the optional sample-level keys.
    """
    subresults = []
    for submodel in self.submodels:
        submodel.to(device=submodel.cf_a.device)
        try:
            subres = submodel(question, passage, span_start, span_end, metadata,
                              get_sample_level_information)
        finally:
            # Free the device even if the submodel raised.
            submodel.to(device=torch.device("cpu"))
        subresults.append(subres)

    batch_size = len(subresults[0]["best_span"])
    best_span = merge_span_probs(subresults)

    output = {
        "best_span": best_span,
        "best_span_str": [],
        "models_output": subresults
    }
    if get_sample_level_information:
        output["em_samples"] = []
        output["f1_samples"] = []

    if metadata is not None:
        best_span_cpu = best_span.detach().cpu().numpy()
        for index in range(batch_size):
            passage_str = metadata[index]['original_passage']
            offsets = metadata[index]['token_offsets']
            predicted_span = tuple(best_span_cpu[index])
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output["best_span_str"].append(best_span_string)

            answer_texts = metadata[index].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
                if get_sample_level_information:
                    em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                    output["em_samples"].append(em_sample)
                    output["f1_samples"].append(f1_sample)

    # Per-sample losses need gold spans; without them there is nothing to score.
    if get_sample_level_information and span_start is not None and span_end is not None:
        output["span_start_sample_loss"] = []
        output["span_end_sample_loss"] = []
        num_models = len(subresults)
        # Ensemble distribution: mean of the submodels' span probabilities.
        span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / num_models
        span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / num_models
        # nll_loss expects log-probabilities, not probabilities.
        span_start_log_probs = torch.log(span_start_probs)
        span_end_log_probs = torch.log(span_end_probs)
        span_start_targets = span_start.squeeze(-1)
        span_end_targets = span_end.squeeze(-1)
        for i in range(batch_size):
            span_start_loss = nll_loss(span_start_log_probs[[i], :], span_start_targets[[i]])
            span_end_loss = nll_loss(span_end_log_probs[[i], :], span_end_targets[[i]])
            output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().item()))
            output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().item()))
    return output