def forward(self,  # type: ignore
            metadata: Dict,
            tokens: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None
            ) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        From a ``TextField`` (that has a bert-pretrained token indexer).
    span_start : torch.IntTensor, optional (default = None)
        A tensor of shape (batch_size, 1) which contains the start position of
        the answer in the passage, or 0 if the question is impossible to
        answer. This is an `inclusive` token index. If this is given, we will
        compute a loss that gets included in the output dictionary.
    span_end : torch.IntTensor, optional (default = None)
        A tensor of shape (batch_size, 1) which contains the end position of
        the answer in the passage, or 0 if the question is impossible to
        answer. This is an `inclusive` token index. If this is given, we will
        compute a loss that gets included in the output dictionary.

    Returns
    -------
    An output dictionary consisting of:
    start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized scores for the span start positions.
    end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized scores for the span end positions.
    start_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        masked-softmax probabilities over the span start positions.
    end_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        masked-softmax probabilities over the span end positions.
    best_span : torch.IntTensor
        A tensor of shape ``(batch_size, 2)`` containing the inclusive
        (start, end) token indices of the highest-scoring span.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; only present when ``span_start`` and
        ``span_end`` are given.
    """
    input_ids = tokens[self._index]
    token_type_ids = tokens[f"{self._index}-type-ids"]
    input_mask = (input_ids != 0).long()

    # 1. Run BERT and project every word piece to a pair of (start, end) scores.
    bert_output, _ = self.bert_model(input_ids, token_type_ids, attention_mask=input_mask)
    linear_output = self.linear(bert_output)
    linear_dropped = self.drop(linear_output)
    start_logits, end_logits = linear_dropped.split(1, dim=-1)
    start_logits, end_logits = start_logits.squeeze(-1), end_logits.squeeze(-1)

    # 2. Compute start and end probabilities, then get the best span using
    # allennlp.models.reading_comprehension.util.get_best_span().
    masked_soft_start = masked_softmax(start_logits, mask=input_mask)
    masked_soft_end = masked_softmax(end_logits, mask=input_mask)
    best_span = get_best_span(masked_soft_start, masked_soft_end)

    output_dict = {
        "start_logits": start_logits,
        "end_logits": end_logits,
        "start_probs": masked_soft_start,
        "end_probs": masked_soft_end,
        "best_span": best_span,
    }

    # 4. Compute the loss and accuracies (span_start accuracy, span_end
    # accuracy and full-span accuracy), but only when gold spans are provided.
    if span_start is not None and span_end is not None:
        self._span_start_accuracy(start_logits, span_start.squeeze(-1))
        self._span_end_accuracy(end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span,
                            torch.stack([span_start.squeeze(-1), span_end.squeeze(-1)], dim=-1))

        # Gold indices clamped to the sequence length are treated as ignored
        # positions by the loss.
        ignored_index = start_logits.size(1)
        span_start.clamp_(0, ignored_index)
        span_end.clamp_(0, ignored_index)
        start_loss = self.loss(start_logits, span_start.squeeze(-1))
        end_loss = self.loss(end_logits, span_end.squeeze(-1))
        output_dict["loss"] = (start_loss + end_loss) / 2

    # 5. Optionally you can compute the official SQuAD metrics (exact match, F1).
    # Instantiate the metric object in __init__ using
    # allennlp.training.metrics.SquadEmAndF1(). When you call it, you need to
    # give it the word tokens of the span (implement and call decode() below)
    # and the gold tokens found in metadata[i]['answer_texts'].
    return output_dict
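
# --- Illustrative sketch (not part of the model above) ----------------------
# A naive re-implementation of the best-span decoding used in step 2, to make
# its semantics concrete: pick the pair (i, j) with j >= i that maximizes
# start_scores[i] + end_scores[j]. AllenNLP's get_best_span() computes the
# same argmax in a vectorized way; this O(n^2) version is for illustration only.
import torch


def naive_best_span(start_scores: torch.Tensor,
                    end_scores: torch.Tensor) -> torch.Tensor:
    """start_scores, end_scores: (batch_size, passage_length)."""
    batch_size, length = start_scores.size()
    spans = []
    for b in range(batch_size):
        best, best_score = (0, 0), float("-inf")
        for i in range(length):
            for j in range(i, length):
                score = start_scores[b, i].item() + end_scores[b, j].item()
                if score > best_score:
                    best_score, best = score, (i, j)
        spans.append(best)
    # (batch_size, 2) inclusive start/end token indices, as in forward() above.
    return torch.tensor(spans)

# Example: the best span below is tokens 1..2:
# naive_best_span(torch.tensor([[0.1, 0.8, 0.1]]),
#                 torch.tensor([[0.1, 0.2, 0.7]]))  # -> tensor([[1, 2]])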
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_starts: torch.IntTensor = None,
            span_ends: torch.IntTensor = None,
            yesno_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    batch_size, num_of_passage_tokens = passage['bert'].size()

    # Execute the BERT model on the word-piece ids (input_ids).
    input_ids = passage['bert']
    token_type_ids = torch.zeros_like(input_ids)
    mask = (input_ids != 0).long()
    embedded_chunk, pooled_output = \
        self._text_field_embedder.token_embedder_bert.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            attention_mask=util.combine_initial_dims(mask),
            output_all_encoded_layers=False)

    # Measure some lengths and offsets to handle the conversion between tokens
    # and word pieces. torch.min() returns the index of the *first* minimal
    # element, i.e. the position of the first padding 0; rows with no padding
    # get the full width instead.
    passage_length = embedded_chunk.size(1)
    mask_min_values, wordpiece_passage_lens = torch.min(mask, dim=1)
    wordpiece_passage_lens[mask_min_values == 1] = mask.shape[1]
    offset_min_values, token_passage_lens = torch.min(passage['bert-offsets'], dim=1)
    token_passage_lens[offset_min_values != 0] = passage['bert-offsets'].shape[1]
    bert_offsets = passage['bert-offsets'].cpu().numpy()

    # BERT for QA adds a fully connected linear layer on top of BERT,
    # producing two vectors of start and end span scores.
    logits = self.qa_outputs(embedded_chunk)
    start_logits, end_logits = logits.split(1, dim=-1)
    span_start_logits = start_logits.squeeze(-1)
    span_end_logits = end_logits.squeeze(-1)

    # All input is preprocessed before forward is run; the size of the yes/no
    # vocabulary indicates whether yes/no support is needed at all.
    if self.vocab.get_vocab_size("yesno_labels") > 1:
        yesno_logits = self.qa_yesno(torch.max(embedded_chunk, 1)[0])

    span_starts.clamp_(0, passage_length)
    span_ends.clamp_(0, passage_length)

    # Move the gold start and end spans from token indexes to word-piece indexes.
    span_starts_list = [bert_offsets[i, span_starts[i]] if span_starts[i] != 0 else 0
                        for i in range(batch_size)]
    span_ends_list = [bert_offsets[i, span_ends[i]] if span_ends[i] != 0 else 0
                      for i in range(batch_size)]
    span_starts = torch.as_tensor(span_starts_list, dtype=torch.long,
                                  device=span_end_logits.device)
    span_ends = torch.as_tensor(span_ends_list, dtype=torch.long,
                                device=span_end_logits.device)

    loss_fct = CrossEntropyLoss(ignore_index=passage_length)
    start_loss = loss_fct(start_logits.squeeze(-1), span_starts)
    end_loss = loss_fct(end_logits.squeeze(-1), span_ends)

    if self.vocab.get_vocab_size("yesno_labels") > 1 and yesno_labels is not None:
        yesno_loss = loss_fct(yesno_logits, yesno_labels)
        loss = (start_loss + end_loss + yesno_loss) / 3
    else:
        loss = (start_loss + end_loss) / 2

    output_dict: Dict[str, Any] = {}
    if loss == 0:
        # For evaluation purposes only!
        output_dict["loss"] = torch.zeros(1, device=span_end_logits.device)
    else:
        output_dict["loss"] = loss
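    # Worked example of the token-to-word-piece conversion above (comment
    # only, and assuming AllenNLP's default "bert-offsets", which point at
    # each token's last word piece): with bert_offsets[i] = [1, 3, 4] and a
    # gold start at token index 1, the word-piece start becomes
    # bert_offsets[i, 1] = 3. Gold indexes that were clamped to passage_length
    # are excluded from the loss via
    # CrossEntropyLoss(ignore_index=passage_length).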
    # Compute F1 and prepare the output dictionary.
    output_dict['best_span_str'] = []
    output_dict['best_span_logit'] = []
    output_dict['cannot_answer_logit'] = []
    output_dict['yesno'] = []
    output_dict['yesno_logit'] = []
    output_dict['qid'] = []
    if span_starts is not None:
        output_dict['EM'] = []
        output_dict['f1'] = []

    # Get the best span prediction.
    best_span = self._get_example_predications(span_start_logits, span_end_logits,
                                               self._max_span_length)
    best_span_cpu = best_span.detach().cpu().numpy()

    for instance_ind, instance_metadata in zip(range(batch_size), metadata):
        best_span_logit = span_start_logits.data.cpu().numpy()[instance_ind, best_span_cpu[instance_ind][0]] + \
                          span_end_logits.data.cpu().numpy()[instance_ind, best_span_cpu[instance_ind][1]]
        cannot_answer_logit = span_start_logits.data.cpu().numpy()[instance_ind, 0] + \
                              span_end_logits.data.cpu().numpy()[instance_ind, 0]

        if self.vocab.get_vocab_size("yesno_labels") > 1:
            yesno_maxind = np.argmax(yesno_logits[instance_ind].data.cpu().numpy())
            yesno_logit = yesno_logits[instance_ind, yesno_maxind].data.cpu().numpy()
            yesno_pred = self.vocab.get_token_from_index(yesno_maxind, namespace="yesno_labels")
        else:
            yesno_pred = 'no_yesno'
            yesno_logit = -30.0

        passage_str = instance_metadata['original_passage']
        offsets = instance_metadata['token_offsets']
        predicted_span = best_span_cpu[instance_ind]

        # In this version, a yes/no prediction other than "no_yesno" is taken
        # as the final answer before the spans are considered.
        if yesno_pred != 'no_yesno':
            best_span_string = yesno_pred
        elif cannot_answer_logit + 0.9 > best_span_logit:
            # A fixed bias of 0.9 favors predicting "cannot_answer".
            best_span_string = 'cannot_answer'
        else:
            # Convert the predicted word-piece span back to character offsets
            # in the original passage.
            wordpiece_offsets = self.bert_offsets_to_wordpiece_offsets(
                bert_offsets[instance_ind][0:len(offsets)])
            start_offset = offsets[wordpiece_offsets[predicted_span[0]
                                                     if predicted_span[0] < len(wordpiece_offsets)
                                                     else len(wordpiece_offsets) - 1]][0]
            end_offset = offsets[wordpiece_offsets[predicted_span[1]
                                                   if predicted_span[1] < len(wordpiece_offsets)
                                                   else len(wordpiece_offsets) - 1]][1]
            best_span_string = passage_str[start_offset:end_offset]

        output_dict['best_span_str'].append(best_span_string)
        output_dict['cannot_answer_logit'].append(cannot_answer_logit)
        output_dict['best_span_logit'].append(best_span_logit)
        output_dict['yesno'].append(yesno_pred)
        output_dict['yesno_logit'].append(yesno_logit)
        output_dict['qid'].append(instance_metadata['question_id'])

        # In AllenNLP prediction mode we have no gold answers, so only compute
        # the official metrics when gold spans are available.
        if span_starts is not None:
            yesno_label_ind = yesno_labels.data.cpu().numpy()[instance_ind]
            yesno_label = self.vocab.get_token_from_index(yesno_label_ind, namespace="yesno_labels")

            if yesno_label != 'no_yesno':
                gold_answer_texts = [yesno_label]
            elif instance_metadata['cannot_answer']:
                gold_answer_texts = ['cannot_answer']
            else:
                gold_answer_texts = instance_metadata['answer_texts_list']

            f1_score = squad_eval.metric_max_over_ground_truths(
                squad_eval.f1_score, best_span_string, gold_answer_texts)
            EM_score = squad_eval.metric_max_over_ground_truths(
                squad_eval.exact_match_score, best_span_string, gold_answer_texts)
            self._official_f1(100 * f1_score)
            self._official_EM(100 * EM_score)
            output_dict['EM'].append(100 * EM_score)
            output_dict['f1'].append(100 * f1_score)

    return output_dict
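
# --- Illustrative sketch (not part of the model above) ----------------------
# bert_offsets_to_wordpiece_offsets() is defined elsewhere in this class; the
# standalone version below is an assumption about its behavior, not the actual
# implementation. Given AllenNLP "bert-offsets" (the index of each token's
# last word piece), it builds the inverse map assigning every word-piece
# position the index of the token it belongs to, which is what lets a
# predicted word-piece span be converted back to character offsets above.
from typing import List


def wordpiece_to_token_map(bert_offsets: List[int]) -> List[int]:  # hypothetical helper
    inverse = []
    prev = 0
    for token_index, last_wordpiece in enumerate(bert_offsets):
        # Every word piece from the previous boundary up to (and including)
        # this token's last word piece belongs to token_index.
        inverse.extend([token_index] * (last_wordpiece - prev + 1))
        prev = last_wordpiece + 1
    return inverse

# Example: tokens whose last word pieces sit at [0, 2, 3]
# (e.g. "a", "play ##ing", "dog") yield the inverse map [0, 1, 1, 2].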