def lens_to_mask(lens: torch.IntTensor) -> torch.Tensor:
    """
    Create a 2-D mask tensor of shape (batch_size, max_length) and dtype float32 from
    a 1-D tensor of integers describing the length of batch samples in another tensor.

    Row ``i`` contains ``1.0`` in its first ``lens[i]`` columns and ``0.0`` elsewhere.
    ``max_length`` is the largest value in ``lens``, or ``0`` for an empty ``lens``
    (giving a ``(0, 0)`` mask). Non-positive lengths yield all-zero rows. The mask is
    created on the same device as ``lens``.
    """
    # ``max()`` over an empty tensor raises, so fall back to a zero-width mask; clamp at 0
    # so a batch of non-positive lengths cannot request a negative size.
    max_length = max(int(lens.max()), 0) if lens.numel() > 0 else 0
    # Compare every column position against every length in one broadcast op instead of
    # a Python loop: (1, max_length) < (batch_size, 1) -> (batch_size, max_length).
    positions = torch.arange(max_length, device=lens.device)
    return (positions.unsqueeze(0) < lens.unsqueeze(1)).to(torch.float32)
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict, for every passage sentence, whether it is evidence (the "gate").

    The gate is computed from the embedded ``passage`` only. ``question`` is embedded
    but the result is not consumed here. ``span_start``, ``span_end``, ``sentence_spans``
    and ``q_type`` are accepted for interface compatibility and are not used.

    Returns a dict containing:
      * ``pred_sent_labels``: the gate thresholded at 0.3, shape (batch_size, num_spans);
      * ``gate_probs``: column 1 of ``pred_sent_probs``, shape (batch_size, num_spans);
      * ``evd_self_attention_score`` when ``self._output_att_scores`` is set and the
        span gate returned a score;
      * ``loss``: gate NLL (label -1 ignored), only when ``sent_labels`` is given;
      * per-example fields copied from ``metadata`` when it is given.

    Gold labels are optional. Without ``sent_labels`` no loss is computed and the
    sentence-F1 metric is not updated, so the method also works for prediction.

    Raises:
        RuntimeError: re-raised from loss tracking after the offending batch is printed.
    """
    if sent_labels is not None and self._sent_labels_src == 'chain':
        batch_size, num_spans = sent_labels.size()
        sent_labels_mask = (sent_labels >= 0).float()
        # we use the chain as the label to supervise the gate
        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
        # right now the number of chains should only be one too.
        evd_chain_labels = evd_chain_labels[:, 0].long()
        # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
        # shape: (batch_size, 1+num_spans)
        sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
        sent_labels.scatter_(1, evd_chain_labels, 1.)
        # remove the column for end embedding
        # shape: (batch_size, num_spans)
        sent_labels = sent_labels[:, 1:].float()
        # make the padding be -1
        sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

    # bert embedding for answer prediction
    # shape: [batch_size, max_q_len, emb_size]
    # NOTE(review): ``embedded_question`` is not used below; kept so the embedder is still
    # invoked on the question exactly as before. Consider removing it.
    embedded_question = self._text_field_embedder(question, num_wrapping_dims=0)
    # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
    embedded_passage = self._text_field_embedder(passage, num_wrapping_dims=1)
    # mask over passage tokens; its size is unpacked below as (batch_size, num_spans, width)
    context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

    # gate prediction
    # Shape(gate_logit): (batch_size * num_spans, 2)
    # Shape(gate): (batch_size * num_spans, 1)
    # Shape(pred_sent_probs): (batch_size * num_spans, 2)
    # Shape(gate_mask): (batch_size, num_spans)
    gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
        embedded_passage, context_mask,
        self._gate_self_attention_layer, self._gate_sent_encoder)
    batch_size, num_spans, _ = context_mask.size()

    loss = None
    if sent_labels is not None:
        # Padding sentences carry label -1 and are excluded via ``ignore_index``.
        loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
                          sent_labels.long().view(batch_size * num_spans),
                          ignore_index=-1)

    # Hard gate decision: a sentence counts as evidence when its gate value is >= 0.3.
    gate = (gate >= 0.3).long().view(batch_size, num_spans)

    output_dict = {
        "pred_sent_labels": gate,  # [B, num_span]
        "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans),  # [B, num_span]
    }
    if self._output_att_scores and g_att_score is not None:
        output_dict['evd_self_attention_score'] = g_att_score

    # Compute the loss for training.
    if loss is not None:
        try:
            self._loss_trackers['loss'](loss)
            output_dict["loss"] = loss
        except RuntimeError:
            # Dump the offending batch for debugging, then propagate the original error.
            print('\n meta_data:', metadata)
            print(gate_logit.shape)
            print("sent label:")
            for b_label in np.array(sent_labels.cpu()):
                b_label = b_label == 1
                indices = np.arange(len(b_label))
                print(indices[b_label] + 1)
            raise

    # Compute the evidence metrics and add the tokenized input to the output.
    if metadata is not None:
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        sent_labels_list = []
        evd_possible_chains = []
        ans_sent_idxs = []
        ids = []
        # Move the hard gate to host memory once instead of once per example.
        pred_gates = gate.detach().cpu().numpy()
        for i in range(batch_size):
            meta = metadata[i]
            question_tokens.append(meta['question_tokens'])
            passage_tokens.append(meta['passage_sent_tokens'])
            sent_labels_list.append(meta['sent_labels'])
            ids.append(meta['_id'])
            answer_texts = meta.get('answer_texts', [])
            output_dict['answer_texts'].append(answer_texts)

            # shift sentence indices back by one (index 0 is dropped from the chain)
            evd_possible_chains.append(
                [s_idx - 1 for s_idx in meta['evd_possible_chains'][0] if s_idx > 0])
            ans_sent_idxs.append([s_idx - 1 for s_idx in meta['ans_sent_idxs']])
            if len(meta['ans_sent_idxs']) > 0:
                # 1 if any answer-bearing sentence was gated on, else 0.
                pred_sent_gate = pred_gates[i]
                hit = any(pred_sent_gate[s_idx - 1] > 0 for s_idx in meta['ans_sent_idxs'])
                self.evd_ans_metric(1 if hit else 0)
        if sent_labels is not None:
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_sent_tokens'] = passage_tokens
        output_dict['sent_labels'] = sent_labels_list
        output_dict['evd_possible_chains'] = evd_possible_chains
        output_dict['ans_sent_idxs'] = ans_sent_idxs
        output_dict['_id'] = ids
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict an answer span, an answer type and a per-sentence evidence gate.

    Two BiDAF-style stacks encode the question/passage: one for answer prediction and
    one (``*_sp``) for the gate. Sentence representations from the answer stack are
    multiplied by the hard gate (threshold 0.3), then scattered back to the token
    sequence and added to the answer stack's passage encoding. Self-attention, start/end
    encoders and a type classifier follow.

    Returns a dict with ``span_start_logits``, ``span_end_logits``, ``best_span``,
    ``pred_sent_labels`` and ``gate_probs``, plus optional attention scores. When
    ``span_start`` is given it also contains ``loss`` (start + end + type losses, plus
    the gate NLL when ``sent_labels`` is given). When ``metadata`` is given, it also
    contains the predicted answer strings and per-example fields.

    Gold sentence labels are optional. Without ``sent_labels`` the gate loss is skipped
    and the sentence-F1 metric is not updated.
    """
    if sent_labels is not None and self._sent_labels_src == 'chain':
        batch_size, num_spans = sent_labels.size()
        sent_labels_mask = (sent_labels >= 0).float()
        # we use the chain as the label to supervise the gate
        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
        # right now the number of chains should only be one too.
        evd_chain_labels = evd_chain_labels[:, 0].long()
        # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
        # shape: (batch_size, 1+num_spans)
        sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
        sent_labels.scatter_(1, evd_chain_labels, 1.)
        # remove the column for end embedding
        # shape: (batch_size, num_spans)
        sent_labels = sent_labels[:, 1:].float()
        # make the padding be -1
        sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

    # word + char embedding
    embedded_question = self._text_field_embedder(question)
    embedded_passage = self._text_field_embedder(passage)
    # mask
    ques_mask = util.get_text_field_mask(question).float()
    context_mask = util.get_text_field_mask(passage).float()

    # BiDAF for answer prediction
    ques_output = self._dropout(self._phrase_layer(embedded_question, ques_mask))
    context_output = self._dropout(self._phrase_layer(embedded_passage, context_mask))
    modeled_passage, _, qc_score = self.qc_att(context_output, ques_output, ques_mask)
    modeled_passage = self._modeling_layer(modeled_passage, context_mask)

    # BiDAF for gate prediction
    ques_output_sp = self._dropout(self._phrase_layer_sp(embedded_question, ques_mask))
    context_output_sp = self._dropout(self._phrase_layer_sp(embedded_passage, context_mask))
    modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(
        context_output_sp, ques_output_sp, ques_mask)
    modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp, context_mask)

    # gate prediction
    # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
    # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
    spans_rep_sp, spans_mask = convert_sequence_to_spans(modeled_passage_sp, sentence_spans)
    spans_rep, _ = convert_sequence_to_spans(modeled_passage, sentence_spans)
    # Shape(gate_logit): (batch_size * num_spans, 2)
    # Shape(gate): (batch_size * num_spans, 1)
    # Shape(pred_sent_probs): (batch_size * num_spans, 2)
    # Shape(gate_mask): (batch_size, num_spans)
    gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
        spans_rep_sp, spans_mask, self._gate_self_attention_layer, self._gate_sent_encoder)
    batch_size, num_spans, _ = spans_mask.size()

    strong_sup_loss = None
    if sent_labels is not None:
        # Padding sentences carry label -1 and are excluded via ``ignore_index``.
        strong_sup_loss = F.nll_loss(
            F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
            sent_labels.long().view(batch_size * num_spans),
            ignore_index=-1)

    # Hard gate (>= 0.3) zeroes out non-evidence sentence representations before they
    # are scattered back onto the token sequence.
    gate = (gate >= 0.3).long()
    spans_rep = spans_rep * gate.unsqueeze(-1).float()
    attended_sent_embeddings = convert_span_to_sequence(
        modeled_passage_sp, spans_rep, spans_mask)
    modeled_passage = attended_sent_embeddings + modeled_passage

    self_att_passage = self._self_attention_layer(modeled_passage, mask=context_mask)
    modeled_passage = modeled_passage + self_att_passage[0]
    self_att_score = self_att_passage[2]

    # Subtracting 1e30 at padded positions keeps them from being selected as span ends.
    output_start = self._span_start_encoder(modeled_passage, context_mask)
    span_start_logits = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask)
    output_end = torch.cat([modeled_passage, output_start], dim=2)
    output_end = self._span_end_encoder(output_end, context_mask)
    span_end_logits = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask)

    output_type = torch.cat([modeled_passage, output_end, output_start], dim=2)
    output_type = torch.max(output_type, 1)[0]
    predict_type = self.linear_type(output_type)
    type_predicts = torch.argmax(predict_type, 1)

    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_end_logits": span_end_logits,
        "best_span": best_span,
        "pred_sent_labels": gate.view(batch_size, num_spans),  # [B, num_span]
        "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans),  # [B, num_span]
    }
    if self._output_att_scores:
        if qc_score is not None:
            output_dict['qc_score'] = qc_score
        if qc_score_sp is not None:
            output_dict['qc_score_sp'] = qc_score_sp
        if self_att_score is not None:
            output_dict['self_attention_score'] = self_att_score
        if g_att_score is not None:
            output_dict['evd_self_attention_score'] = g_att_score

    # Compute the loss for training.
    if span_start is not None:
        try:
            start_loss = nll_loss(
                util.masked_log_softmax(span_start_logits, None), span_start.squeeze(-1))
            end_loss = nll_loss(
                util.masked_log_softmax(span_end_logits, None), span_end.squeeze(-1))
            type_loss = nll_loss(util.masked_log_softmax(predict_type, None), q_type)
            loss = start_loss + end_loss + type_loss
            if strong_sup_loss is not None:
                loss = loss + strong_sup_loss
            self._loss_trackers['loss'](loss)
            self._loss_trackers['start_loss'](start_loss)
            self._loss_trackers['end_loss'](end_loss)
            self._loss_trackers['type_loss'](type_loss)
            if strong_sup_loss is not None:
                self._loss_trackers['strong_sup_loss'](strong_sup_loss)
            output_dict["loss"] = loss
        except RuntimeError:
            # Best-effort diagnostics: the error is swallowed, so this batch's output
            # has no "loss" entry.
            print('\n meta_data:', metadata)
            print(span_start_logits.shape)
            if sent_labels is not None:
                print("sent label:")
                for b_label in np.array(sent_labels.cpu()):
                    b_label = b_label == 1
                    indices = np.arange(len(b_label))
                    print(indices[b_label] + 1)

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        token_spans_sp = []
        token_spans_sent = []
        sent_labels_list = []
        evd_possible_chains = []
        ans_sent_idxs = []
        ids = []
        # Move predictions to host memory once instead of once per example.
        pred_types = type_predicts.tolist()
        pred_spans = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            meta = metadata[i]
            question_tokens.append(meta['question_tokens'])
            passage_tokens.append(meta['passage_tokens'])
            token_spans_sp.append(meta['token_spans_sp'])
            token_spans_sent.append(meta['token_spans_sent'])
            sent_labels_list.append(meta['sent_labels'])
            ids.append(meta['_id'])
            passage_str = meta['original_passage']
            offsets = meta['token_offsets']
            # Answer type 1 -> "yes", 2 -> "no", anything else -> extract the span text.
            if pred_types[i] == 1:
                best_span_string = 'yes'
            elif pred_types[i] == 2:
                best_span_string = 'no'
            else:
                predicted_span = tuple(pred_spans[i])
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = meta.get('answer_texts', [])
            output_dict['answer_texts'].append(answer_texts)
            if answer_texts:
                self._squad_metrics(best_span_string.lower(), answer_texts)
            # shift sentence indices back by one (index 0 is dropped from the chain)
            evd_possible_chains.append(
                [s_idx - 1 for s_idx in meta['evd_possible_chains'][0] if s_idx > 0])
            ans_sent_idxs.append([s_idx - 1 for s_idx in meta['ans_sent_idxs']])
        if sent_labels is not None:
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
        output_dict['token_spans_sp'] = token_spans_sp
        output_dict['token_spans_sent'] = token_spans_sent
        output_dict['sent_labels'] = sent_labels_list
        output_dict['evd_possible_chains'] = evd_possible_chains
        output_dict['ans_sent_idxs'] = ans_sent_idxs
        output_dict['_id'] = ids
    return output_dict