def forward_qa(self, input_ids, attention_mask, start_positions, end_positions):
    # Do forward pass on DistilBERT
    outputs = self.qa_model(
        input_ids,
        attention_mask=attention_mask,
        start_positions=start_positions,
        end_positions=end_positions,
        output_hidden_states=True,
    )
    # Get final hidden state from DistilBERT output
    last_hidden_state = outputs["hidden_states"][-1]
    # Get output layer logits (start and end)
    logits = self.qa_outputs(last_hidden_state)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # Sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        # Use the final hidden state to get the targets from the discriminator model
        hidden = last_hidden_state[:, 0]  # same as cls_embedding
        log_prob = self.discriminator_model(hidden)
        targets = torch.ones_like(log_prob) * (1 / self.num_classes)

        # Compute KL loss
        kl_criterion = nn.KLDivLoss(reduction="batchmean")
        kld = self.discriminator_lambda * kl_criterion(log_prob, targets)

        # Compute total loss by combining QA loss with KLD loss
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        qa_loss = (start_loss + end_loss) / 2
        total_loss = qa_loss + kld

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

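# Hedged sketch (not part of the original model): the KL term above pushes the discriminator's
# prediction on the [CLS] embedding toward a uniform distribution over domains, encouraging
# domain-invariant features. nn.KLDivLoss expects the *input* as log-probabilities and the
# *target* as probabilities, so this only works if `self.discriminator_model` ends in a
# log-softmax. The names below (`num_domains`, `hidden_dim`) are illustrative assumptions.
import torch
import torch.nn as nn

num_domains, hidden_dim, batch_size = 3, 768, 16
discriminator = nn.Sequential(nn.Linear(hidden_dim, num_domains), nn.LogSoftmax(dim=-1))

cls_embedding = torch.randn(batch_size, hidden_dim)           # stand-in for last_hidden_state[:, 0]
log_prob = discriminator(cls_embedding)                        # (batch_size, num_domains), log-probabilities
uniform_target = torch.full_like(log_prob, 1.0 / num_domains)  # uniform distribution over domains
kld = nn.KLDivLoss(reduction="batchmean")(log_prob, uniform_target)
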
def forward(self, input_ids, attention_mask, start_positions=None,
            end_positions=None, labels=None, model_type=None):
    """
    Parameters
    ----------
    input_ids : shape [16, 384], i.e. [batch_size, max_embedding_length]
    attention_mask : shape [16, 384]
    start_positions : shape [16]
    end_positions : shape [16]
    """
    if model_type == 'qa_model':
        qa_loss = self.forward_qa(input_ids, attention_mask, start_positions, end_positions)
        return qa_loss
    elif model_type == 'discriminator_model':
        discriminator_loss = self.forward_discriminator(
            input_ids, attention_mask, start_positions, end_positions, labels)
        return discriminator_loss
    else:
        # For evaluation
        outputs = self.qa_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden_state = outputs["hidden_states"][-1]
        logits = self.qa_outputs(last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)

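# Hedged usage sketch (illustrative only): how the dispatching forward above might be driven
# during adversarial training and evaluation. `model`, `qa_batch`, `disc_batch`, and
# `domain_labels` are hypothetical names, not part of the original code.
qa_out = model(
    input_ids=qa_batch["input_ids"],
    attention_mask=qa_batch["attention_mask"],
    start_positions=qa_batch["start_positions"],
    end_positions=qa_batch["end_positions"],
    model_type="qa_model",
)  # QA loss (plus KL regularizer) for the QA update step

disc_out = model(
    input_ids=disc_batch["input_ids"],
    attention_mask=disc_batch["attention_mask"],
    labels=domain_labels,
    model_type="discriminator_model",
)  # discriminator update step

eval_out = model(
    input_ids=qa_batch["input_ids"],
    attention_mask=qa_batch["attention_mask"],
)  # model_type=None -> evaluation path, logits only
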
def _process_data(self, inputs, return_dict):
    inp_length = inputs[self.main_input_name].shape[1]

    # If <max_length> is specified, pad the inputs with zeros
    if inp_length < self.max_length:
        for name in inputs:
            shape = inputs[name].shape
            if shape[1] != self.max_length:
                pad = np.zeros([len(shape), 2], dtype=np.int32)
                pad[1, 1] = self.max_length - shape[1]
                inputs[name] = np.pad(inputs[name], pad)

    # OpenVINO >= 2022.1 supports dynamically shaped inputs.
    if not is_openvino_api_2:
        inputs_info = self.net.input_info
        input_ids = inputs[self.main_input_name]
        if inputs_info[self.main_input_name].input_data.shape[1] != input_ids.shape[1]:
            # Use batch size 1 because we process the batch sequentially.
            shapes = {key: [1] + list(inputs[key].shape[1:]) for key in inputs_info}
            logger.info(f"Reshape model to {shapes}")
            self.net.reshape(shapes)
            self.exec_net = None
    elif is_openvino_api_2 and not self.use_dynamic_shapes:
        # TODO
        pass

    if self.exec_net is None:
        self._load_network()

    if is_openvino_api_2:
        outs = self._process_data_api_2022(inputs)
    else:
        outs = self._process_data_api_2021(inputs)

    logits = outs["output"] if "output" in outs else next(iter(outs.values()))

    past_key_values = None
    if self.config.architectures[0].endswith("ForConditionalGeneration") and self.config.use_cache:
        past_key_values = [[]]
        for name in outs:
            if name == "output":
                continue
            if len(past_key_values[-1]) == 4:
                past_key_values.append([])
            past_key_values[-1].append(torch.tensor(outs[name]))
        past_key_values = tuple([tuple(val) for val in past_key_values])

    # Truncate padded values
    if inp_length != logits.shape[1]:
        logits = logits[:, :inp_length]

    if not return_dict:
        return [logits]

    arch = self.config.architectures[0]
    if arch.endswith("ForSequenceClassification"):
        return SequenceClassifierOutput(logits=logits)
    elif arch.endswith("ForQuestionAnswering"):
        return QuestionAnsweringModelOutput(start_logits=outs["output_s"],
                                            end_logits=outs["output_e"])
    else:
        return ModelOutput(logits=torch.tensor(logits), past_key_values=past_key_values)

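# Hedged illustration (not from the original repo): the pad-width array built above has shape
# (ndim, 2); only pad[1, 1] is set, so np.pad appends zeros at the end of the sequence axis and
# leaves the batch axis untouched. `max_length` and `x` below are illustrative.
import numpy as np

max_length = 8
x = np.ones((1, 5), dtype=np.int64)          # [batch_size, seq_len] with seq_len < max_length
pad = np.zeros([x.ndim, 2], dtype=np.int32)  # one (before, after) pair per axis
pad[1, 1] = max_length - x.shape[1]          # pad only after the last token
padded = np.pad(x, pad)                      # shape (1, 8): [1, 1, 1, 1, 1, 0, 0, 0]
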
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    start_positions=None,
    end_positions=None,
    return_dict=None,
    output_hidden_states=None,
    output_attentions=None,
):
    r"""
    start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]

    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    # <- Answer Selection Part
    start_indexes = squad_metrics._get_best_indexes(start_logits.tolist(), n_best_size=41)
    end_indexes = squad_metrics._get_best_indexes(end_logits.tolist(), n_best_size=41)
    candidate_spans = (start_indexes, end_indexes)
    feat = self.features
    # spans in the original are structured like [passages, number of candidates, span of the answer]
    self.candidate_representation.calculate_candidate_representations(
        spans=candidate_spans, features=feat, seq_outpu=sequence_output)
    r_Ctilde = self.candidate_representation.tilda_r_Cs
    p_C = self.score_answers(r_Ctilde)
    sorted_tensor, index = torch.sort(p_C, descending=True)

    def helpfunction(ind, features):
        top1 = None
        placeholder = None
        top2 = None
        while top1 is None and top2 is None:
            print("Features[0].end_position", features[40].end_position)
            for n in ind:
                if features[ind[n]].end_position == 0 and features[ind[n]].start_position == 0:
                    continue
                elif top1 is not None:
                    top2 = n
                else:
                    top1 = n
                    placeholder = n - 1
                    break
            if top1 is None:
                return 0, -1
            elif top2 is None:
                return top1, placeholder
        return top1, top2

    answer1, answer2 = helpfunction(index, self.features)
    answerdict = defaultdict(dict)
    # Key the results by candidate index so they can be looked up via answer1/answer2 below
    for ans in [answer1, answer2]:
        logits = self.qa_outputs(sequence_output[ans])
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions[index])
            end_loss = loss_fct(end_logits, end_positions[index])
            total_loss = (start_loss + end_loss) / 2

        answerdict[ans]["loss"] = total_loss
        answerdict[ans]["start_logits"] = start_logits
answerdict[ans]["end_logits"] = end_logits if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss, ) + output) if total_loss is not None else output if answerdict[answer1]["loss"] < answerdict[answer2]["loss"]: gold_ans = answer1 else: gold_ans = answer2 return QuestionAnsweringModelOutput( loss=answerdict[gold_ans]["loss"], start_logits=answerdict[gold_ans]["start_logits"], end_logits=answerdict[gold_ans]["end_logits"], hidden_states=outputs.hidden_states, attentions=outputs.attentions, )
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]

    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split adds a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]

    logits = self.new_outputs(sequence_output)  # qa_outputs
    start_logits, end_logits, center_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)
    center_logits = center_logits.squeeze(-1)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)
        # center_positions.clamp_(0, ignored_index)
        mean_positions = torch.mean(torch.stack([start_positions, end_positions], 0).float(), 0)
        # print('size =', start_positions.size(), mean_positions.size())
        center_positions = mean_positions.long()  # round

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        center_loss = loss_fct(center_logits, center_positions)
        total_loss = (start_loss + end_loss + center_loss) / 3

    if not return_dict:
        output = (start_logits, end_logits, end_logits) + outputs[2:]  # return center_logits or not !
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

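# Worked example (illustrative, not part of the model): the auxiliary "center" target above is
# the element-wise mean of the start and end positions; .long() truncates toward zero, e.g.
# start=3, end=8 -> 5.5 -> 5.
import torch

start_positions = torch.tensor([3, 10])
end_positions = torch.tensor([8, 15])
center_positions = torch.mean(
    torch.stack([start_positions, end_positions], 0).float(), 0
).long()  # tensor([5, 12]): (3+8)/2 = 5.5 -> 5, (10+15)/2 = 12.5 -> 12
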
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]

    # UNet head
    enc_ftrs = self.encoder(sequence_output)
    out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
    logits = self.head(out)
    logits = torch.transpose(logits, 1, 2)
    # logits = self.qa_outputs(sequence_output)

    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split adds a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    global_attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    # [batch_size, sequence_length, hidden_size]
    outputs = self.longformer(
        input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]

    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split adds a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    outputs = self.backbone(
        self.random_masking(input_ids) if (self.training and self.masking_ratio != 0.0) else input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    sequence_output = outputs[0]

    logits = None
    if self.head == "CCNN_LSTM_EM" or self.head == "CCNN_EM":
        exact_match_token = self.get_exact_match_token(input_ids)
        logits = self.qa_outputs((sequence_output, exact_match_token))
    else:
        logits = self.qa_outputs(sequence_output)

    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        output = (start_logits, end_logits) + outputs[self.pooling_pos:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    distilbert_output = self.distilbert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = distilbert_output[0]         # (bs, max_query_len, dim)
    hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)

    logits = gelu_new(self.qa_outputs_0(hidden_states))
    logits = gelu_new(self.qa_outputs_1(logits))
    # logits = self.LayerNorm_0(logits)
    logits = self.qa_outputs(logits)             # (bs, max_query_len, 2)
    logits = self.LayerNorm(logits)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)      # (bs, max_query_len)
    end_logits = end_logits.squeeze(-1)          # (bs, max_query_len)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        output = (start_logits, end_logits) + distilbert_output[1:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=distilbert_output.hidden_states,
        attentions=distilbert_output.attentions,
    )

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """
    start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    distilbert_output = self.distilbert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    logits = distilbert_output[0].permute(0, 2, 1)

    aspp_r3 = gelu_new(self.qa_aspp_r6_1x1(self.qa_aspp_r3(logits)))
    aspp_r6 = gelu_new(self.qa_aspp_r6_1x1(self.qa_aspp_r6(logits)))
    aspp_r12 = gelu_new(self.qa_aspp_r12_1x1(self.qa_aspp_r12(logits)))
    out_aspp = torch.cat((aspp_r3, aspp_r6, aspp_r12), 1)

    logits = gelu_new(self.qa_aspp_score(out_aspp))
    logits = logits.unsqueeze(dim=3)
    logits = self.upsampling2D(logits)
    logits = logits[:, :, :, 0]

    start_logits, end_logits = logits.permute(0, 2, 1).split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)  # (batch_size, max_query_len)
    end_logits = end_logits.squeeze(-1)      # (batch_size, max_query_len)
    # print(start_logits.shape)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        output = (start_logits, end_logits) + distilbert_output[1:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=distilbert_output.hidden_states,
        attentions=distilbert_output.attentions,
    )

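# Hedged sketch (illustrative, not the repo's actual module definitions): each ASPP branch above
# is presumably a 1-D dilated convolution over the token axis. With kernel size 3 and padding equal
# to the dilation rate, every branch preserves the sequence length, so the three branches can be
# concatenated along the channel dimension as done with torch.cat(..., 1). The sizes below are
# assumptions for demonstration only.
import torch
import torch.nn as nn

dim, seq_len, batch_size = 768, 384, 2
x = torch.randn(batch_size, dim, seq_len)  # (bs, channels, tokens), as after permute(0, 2, 1)
branches = [nn.Conv1d(dim, 256, kernel_size=3, padding=r, dilation=r) for r in (3, 6, 12)]
out = torch.cat([b(x) for b in branches], dim=1)  # (bs, 3 * 256, tokens): length preserved
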
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    start_positions=None,
    end_positions=None,
    title=None,
    t_mask=None,
    t_lens=None,
    c_lens=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    with torch.no_grad():
        hyper_inputs = self.albert(
            input_ids=title,
            attention_mask=t_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        infer_inputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    state = self.hypernet.init_state(hyper_inputs[0])
    title_len = t_lens
    content_len = c_lens

    check = torch.isnan(hyper_inputs[0])
    if True in check:
        print('nan in hyper_inputs!')
    check = torch.isnan(infer_inputs[0])
    if True in check:
        print('nan in infer_inputs!')

    outputs, state = self.hypernet(hyper_inputs[0], state)

    check = torch.isnan(outputs)
    if True in check:
        print('nan in outputs!')

    h_hat_t = torch.stack([t[l - 1] for (t, l) in zip(outputs, title_len)])
    if isinstance(state, tuple):
        state = list(state)
        state[0] = state[0][-1]
        state[1] = state[1][-1]
    else:
        state = state[-1]

    infer_inputs_ = infer_inputs[0].transpose(0, 1).contiguous()
    infer_outputs = self.infernet(state, h_hat_t, infer_inputs_)

    check = torch.isnan(infer_outputs)
    if True in check:
        print('nan in infer outputs!')
    check = torch.isnan(infer_outputs)
    if True in check:
        print('nan in state of infer outputs!')

    # concat title and context
    sequence_output = infer_outputs
    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split adds a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=infer_inputs.hidden_states,
        attentions=infer_inputs.attentions,
    )

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    is_impossible=None,
    pq_end_pos=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
        sequence are not taken into account for computing the loss.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    discriminator_hidden_states = self.electra(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    sequence_output = discriminator_hidden_states[0]

    query_sequence_output, _, query_attention_mask, _ = split_ques_context(sequence_output, pq_end_pos)
    sequence_output = self.attention(sequence_output, query_sequence_output, query_attention_mask)
    sequence_output = sequence_output + discriminator_hidden_states[0]

    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    first_word = sequence_output[:, 0, :]
    has_log = self.has_ans(first_word)

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split adds a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        if len(is_impossible.size()) > 1:
            is_impossible = is_impossible.squeeze(-1)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        choice_loss = loss_fct(has_log, is_impossible)
        total_loss = (start_loss + end_loss + choice_loss) / 3

    if not return_dict:
        output = (start_logits, end_logits) + discriminator_hidden_states[1:]
        return ((total_loss,) + output) if total_loss is not None else output

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=discriminator_hidden_states.hidden_states,
        attentions=discriminator_hidden_states.attentions,
    )