Example #1
    def forward_qa(self, input_ids, attention_mask, start_positions,
                   end_positions):
        # Do forward pass on DistilBERT
        outputs = self.qa_model(input_ids,
                                attention_mask=attention_mask,
                                start_positions=start_positions,
                                end_positions=end_positions,
                                output_hidden_states=True)

        # Get final hidden state from DistilBERT output
        last_hidden_state = outputs["hidden_states"][-1]

        # Get output layer logits (start and end)
        logits = self.qa_outputs(last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            # Use the final hidden state to get the targets from the discriminator model
            hidden = last_hidden_state[:, 0]  # same as cls_embedding
            log_prob = self.discriminator_model(hidden)
            targets = torch.ones_like(log_prob) * (1 / self.num_classes)

            # Compute KL loss
            kl_criterion = nn.KLDivLoss(reduction="batchmean")
            kld = self.discriminator_lambda * kl_criterion(log_prob, targets)

            # Compute total loss by combining QA loss with KLD loss
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            qa_loss = (start_loss + end_loss) / 2
            total_loss = qa_loss + kld
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
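The KL term above pushes the discriminator output toward a uniform distribution. As a reminder of the convention this relies on, nn.KLDivLoss expects log-probabilities as its first argument and probabilities as its target; the following standalone sketch (with made-up shapes and a hypothetical num_classes) reproduces just that step.

import torch
import torch.nn as nn

num_classes = 4                                                     # hypothetical
log_prob = torch.log_softmax(torch.randn(8, num_classes), dim=-1)   # stand-in for the discriminator output
targets = torch.ones_like(log_prob) * (1 / num_classes)             # uniform target distribution
kl_criterion = nn.KLDivLoss(reduction="batchmean")
kld = kl_criterion(log_prob, targets)                               # input = log-probs, target = probs
print(kld)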
Example #2
    def forward(self,
                input_ids,
                attention_mask,
                start_positions=None,
                end_positions=None,
                labels=None,
                model_type=None):
        """
        Parameters
        ----------
        input_ids has shape [16, 384], i.e. [batch_size, max_embedding_length]
        attention_mask has shape [16, 384]
        start_positions has shape [16]
        end_positions has shape [16]
        """
        if model_type == 'qa_model':
            qa_loss = self.forward_qa(input_ids, attention_mask,
                                      start_positions, end_positions)
            return qa_loss
        elif model_type == 'discriminator_model':
            discriminator_loss = self.forward_discriminator(
                input_ids, attention_mask, start_positions, end_positions,
                labels)
            return discriminator_loss
        else:
            # For evaluation
            outputs = self.qa_model(input_ids,
                                    attention_mask=attention_mask,
                                    output_hidden_states=True)
            last_hidden_state = outputs["hidden_states"][-1]
            logits = self.qa_outputs(last_hidden_state)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1)
            end_logits = end_logits.squeeze(-1)

            return QuestionAnsweringModelOutput(start_logits=start_logits,
                                                end_logits=end_logits)
Example #3
    def _process_data(self, inputs, return_dict):
        inp_length = inputs[self.main_input_name].shape[1]

        # If <max_length> is specified, pad inputs with zeros
        if inp_length < self.max_length:
            for name in inputs:
                shape = inputs[name].shape
                if shape[1] != self.max_length:
                    pad = np.zeros([len(shape), 2], dtype=np.int32)
                    pad[1, 1] = self.max_length - shape[1]
                    inputs[name] = np.pad(inputs[name], pad)

        # OpenVINO >= 2022.1 supports dynamically shaped inputs.
        if not is_openvino_api_2:
            inputs_info = self.net.input_info
            input_ids = inputs[self.main_input_name]
            if inputs_info[self.main_input_name].input_data.shape[
                    1] != input_ids.shape[1]:
                # Use batch size 1 because we process the batch sequentially.
                shapes = {
                    key: [1] + list(inputs[key].shape[1:])
                    for key in inputs_info
                }
                logger.info(f"Reshape model to {shapes}")
                self.net.reshape(shapes)
                self.exec_net = None
        elif is_openvino_api_2 and not self.use_dynamic_shapes:
            # TODO
            pass

        if self.exec_net is None:
            self._load_network()

        if is_openvino_api_2:
            outs = self._process_data_api_2022(inputs)
        else:
            outs = self._process_data_api_2021(inputs)

        logits = outs["output"] if "output" in outs else next(
            iter(outs.values()))

        past_key_values = None
        if self.config.architectures[0].endswith(
                "ForConditionalGeneration") and self.config.use_cache:
            past_key_values = [[]]
            for name in outs:
                if name == "output":
                    continue
                if len(past_key_values[-1]) == 4:
                    past_key_values.append([])
                past_key_values[-1].append(torch.tensor(outs[name]))

            past_key_values = tuple([tuple(val) for val in past_key_values])

        # Truncate padded values
        if inp_length != logits.shape[1]:
            logits = logits[:, :inp_length]

        if not return_dict:
            return [logits]

        arch = self.config.architectures[0]
        if arch.endswith("ForSequenceClassification"):
            return SequenceClassifierOutput(logits=logits)
        elif arch.endswith("ForQuestionAnswering"):
            return QuestionAnsweringModelOutput(start_logits=outs["output_s"],
                                                end_logits=outs["output_e"])
        else:
            return ModelOutput(logits=torch.tensor(logits),
                               past_key_values=past_key_values)
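The padding branch at the top of Example #3 builds a per-axis (before, after) pad-width array and pads only the sequence axis up to max_length. A standalone sketch with made-up sizes:

import numpy as np

max_length = 8
input_ids = np.array([[101, 7592, 2088, 102]])          # shape (1, 4), made-up token ids
pad = np.zeros([input_ids.ndim, 2], dtype=np.int32)     # one (before, after) pair per axis
pad[1, 1] = max_length - input_ids.shape[1]             # pad only the sequence axis, at the end
padded = np.pad(input_ids, pad)                         # zeros are the default pad value
print(padded.shape)                                     # (1, 8)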
Example #4
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        start_positions=None,
        end_positions=None,
        return_dict=None,
        output_hidden_states=None,
        output_attentions=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # <- Answer Selection Part
        start_indexes = squad_metrics._get_best_indexes(start_logits.tolist(),
                                                        n_best_size=41)
        end_indexes = squad_metrics._get_best_indexes(end_logits.tolist(),
                                                      n_best_size=41)
        candidate_spans = (start_indexes, end_indexes)
        feat = self.features

        # spans in the original are structured like [passages, number of candidates, span of the answer]
        self.candidate_representation.calculate_candidate_representations(
            spans=candidate_spans, features=feat, seq_outpu=sequence_output)

        r_Ctilde = self.candidate_representation.tilda_r_Cs
        p_C = self.score_answers(r_Ctilde)

        sorted_tensor, index = torch.sort(p_C, descending=True)

        def helpfunction(ind, features):
            # Walk the candidates in ranked order and pick the top two whose
            # features carry a labelled span (start/end position != 0).
            top1 = None
            top2 = None

            for n in ind.tolist():
                if features[n].end_position == 0 and features[n].start_position == 0:
                    continue
                if top1 is None:
                    top1 = n
                elif top2 is None:
                    top2 = n
                    break

            if top1 is None:
                return 0, -1
            if top2 is None:
                return top1, top1 - 1
            return top1, top2

        answer1, answer2 = helpfunction(index, self.features)

        answerdict = defaultdict(dict)

        # Key answerdict by the candidate index so the lookups below
        # (answerdict[answer1], answerdict[answer2]) find the stored entries.
        for ans in [answer1, answer2]:
            logits = self.qa_outputs(sequence_output[ans])
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1)
            end_logits = end_logits.squeeze(-1)

            total_loss = None
            if start_positions is not None and end_positions is not None:
                # If we are on multi-GPU, split add a dimension
                if len(start_positions.size()) > 1:
                    start_positions = start_positions.squeeze(-1)
                if len(end_positions.size()) > 1:
                    end_positions = end_positions.squeeze(-1)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions.clamp_(0, ignored_index)
                end_positions.clamp_(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions[index])
                end_loss = loss_fct(end_logits, end_positions[index])
                total_loss = (start_loss + end_loss) / 2
                answerdict[ans]["loss"] = total_loss
                answerdict[ans]["start_logits"] = start_logits
                answerdict[ans]["end_logits"] = end_logits

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss, ) +
                    output) if total_loss is not None else output

        if answerdict[answer1]["loss"] < answerdict[answer2]["loss"]:
            gold_ans = answer1
        else:
            gold_ans = answer2

        return QuestionAnsweringModelOutput(
            loss=answerdict[gold_ans]["loss"],
            start_logits=answerdict[gold_ans]["start_logits"],
            end_logits=answerdict[gold_ans]["end_logits"],
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
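For reference, torch.sort with descending=True (as used above to rank the candidate scores p_C) returns both the sorted values and the original indices; helpfunction then walks those indices. A tiny standalone illustration:

import torch

p_C = torch.tensor([0.1, 0.7, 0.2])
sorted_scores, index = torch.sort(p_C, descending=True)
print(sorted_scores)   # tensor([0.7000, 0.2000, 0.1000])
print(index)           # tensor([1, 2, 0]) -- candidate positions, best first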
Example #5
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
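Example #5 is the stock BERT question-answering head, and the same split/clamp/cross-entropy pattern recurs in most of the other examples. A self-contained sketch with dummy tensors (sizes chosen arbitrarily) shows why the clamp works: out-of-range positions are clamped to ignored_index, which CrossEntropyLoss then skips.

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, hidden = 2, 16, 32
sequence_output = torch.randn(batch_size, seq_len, hidden)
qa_outputs = torch.nn.Linear(hidden, 2)                 # joint start/end projection

logits = qa_outputs(sequence_output)                    # (batch, seq, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)                 # (batch, seq)
end_logits = end_logits.squeeze(-1)

start_positions = torch.tensor([3, 20])                 # second label lies outside the sequence
end_positions = torch.tensor([5, 25])
ignored_index = start_logits.size(1)                    # = seq_len
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)   # out-of-range targets become ignored_index

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) +
              loss_fct(end_logits, end_positions)) / 2
print(total_loss)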
Example #6
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.new_outputs(sequence_output)  # qa_outputs
        start_logits, end_logits, center_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        center_logits = center_logits.squeeze(-1)

        #
        total_loss = None
        if start_positions is not None and end_positions is not None:
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
            # center_positions.clamp_(0, ignored_index)
            # Center target = midpoint of start/end; .long() truncates rather than rounds
            mean_positions = torch.mean(
                torch.stack([start_positions, end_positions], 0).float(), 0)
            center_positions = mean_positions.long()

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            center_loss = loss_fct(center_logits, center_positions)
            total_loss = (start_loss + end_loss + center_loss) / 3

        if not return_dict:
            # expose center_logits in the tuple output alongside start/end
            output = (start_logits, end_logits,
                      center_logits) + outputs[2:]
            return ((total_loss, ) +
                    output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
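The extra center target in Example #6 is just the midpoint of the labelled span; note that .long() truncates the .5 cases rather than rounding them. A dummy-valued sketch:

import torch

start_positions = torch.tensor([3, 10])
end_positions = torch.tensor([7, 15])
mean_positions = torch.mean(
    torch.stack([start_positions, end_positions], 0).float(), 0)
center_positions = mean_positions.long()   # (3+7)/2 = 5, (10+15)/2 = 12.5 -> 12
print(center_positions)                    # tensor([ 5, 12])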
Example #7
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # UNet
        enc_ftrs = self.encoder(sequence_output)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        logits   = self.head(out)
        logits = torch.transpose(logits, 1, 2)
        
        # logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Example #8
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # [batch_size, sequence_length, hidden_size]
        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss, ) +
                    output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Example #9
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        outputs = self.backbone(
            self.random_masking(input_ids) if (self.training and self.masking_ratio != 0.0) else input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs[0]

        logits = None
        if self.head == "CCNN_LSTM_EM" or self.head == "CCNN_EM":
            exact_match_token = self.get_exact_match_token(input_ids)
            logits = self.qa_outputs((sequence_output, exact_match_token))
        else:
            logits = self.qa_outputs(sequence_output)

        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[self.pooling_pos :]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
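Example #9 masks a fraction of the input tokens during training via self.random_masking, whose implementation is not shown. Purely as an illustration of what such a helper might look like (the function name, mask token id, and special-token handling here are assumptions, not the example's actual code):

import torch

def random_masking(input_ids, masking_ratio=0.15, mask_token_id=103,
                   special_token_ids=(0, 101, 102)):
    # Replace a random subset of non-special tokens with the [MASK] id (assumed behaviour).
    mask = torch.rand_like(input_ids, dtype=torch.float) < masking_ratio
    for tok in special_token_ids:
        mask &= input_ids != tok
    return torch.where(mask, torch.full_like(input_ids, mask_token_id), input_ids)

ids = torch.tensor([[101, 7592, 2088, 2003, 2307, 102]])
print(random_masking(ids, masking_ratio=0.5))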
Example #10
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=None,
            start_positions=None,
            end_positions=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)

        logits = gelu_new(self.qa_outputs_0(hidden_states))  # (bs, max_query_len, 2)
        logits = gelu_new(self.qa_outputs_1(logits))
        # logits = self.LayerNorm_0(logits)

        logits = self.qa_outputs(logits)
        logits = self.LayerNorm(logits)

        start_logits, end_logits = logits.split(1, dim=-1)

        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions
        )
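Example #10 interleaves its extra projection layers with gelu_new, the tanh-approximated GELU used in GPT-2-style models. For reference, a standalone definition matching that approximation (written out here rather than imported):

import math
import torch

def gelu_new(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

print(gelu_new(torch.tensor([-1.0, 0.0, 1.0])))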
Example #11
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=None,
            start_positions=None,
            end_positions=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        """
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = distilbert_output[0].permute(0, 2, 1)

        aspp_r3 = gelu_new(self.qa_aspp_r6_1x1(self.qa_aspp_r3(logits)))
        aspp_r6 = gelu_new(self.qa_aspp_r6_1x1(self.qa_aspp_r6(logits)))
        aspp_r12 = gelu_new(self.qa_aspp_r12_1x1(self.qa_aspp_r12(logits)))

        out_aspp = torch.cat((aspp_r3, aspp_r6, aspp_r12), 1)

        logits = gelu_new(self.qa_aspp_score(out_aspp))

        logits = logits.unsqueeze(dim=3)

        logits = self.upsampling2D(logits)
        logits = logits[:, :, :, 0]

        start_logits, end_logits = logits.permute(0, 2, 1).split(1, dim=-1)

        start_logits = start_logits.squeeze(-1)  # (batch_size, max_query_len)
        end_logits = end_logits.squeeze(-1)  # (batch_size, max_query_len)

        # print(start_logits.shape)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions
        )
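The ASPP-style branches in Example #11 run dilated 1-D convolutions over the token dimension, which is why the hidden states are permuted to (batch, channels, length) before the convolutions and permuted back afterwards. A minimal sketch with assumed layer sizes (the channel count and dilation here are illustrative, not the example's actual configuration):

import torch
import torch.nn as nn

bs, seq_len, dim = 2, 384, 768
hidden_states = torch.randn(bs, seq_len, dim)

conv_r3 = nn.Conv1d(dim, 64, kernel_size=3, dilation=3, padding=3)  # "rate 3" branch
x = hidden_states.permute(0, 2, 1)      # Conv1d expects (batch, channels, length)
out = conv_r3(x)                        # (bs, 64, seq_len); padding keeps the length
out = out.permute(0, 2, 1)              # back to (bs, seq_len, 64)
print(out.shape)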
Example #12
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        title=None,
        t_mask=None,
        t_lens=None,
        c_lens=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        with torch.no_grad():

            hyper_inputs = self.albert(
                input_ids=title,
                attention_mask=t_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            infer_inputs = self.albert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        state = self.hypernet.init_state(hyper_inputs[0])
        title_len = t_lens
        content_len = c_lens
        # Debug checks: warn if NaNs appear in the encoder outputs
        if torch.isnan(hyper_inputs[0]).any():
            print('nan in hyper_inputs!')
        if torch.isnan(infer_inputs[0]).any():
            print('nan in infer_inputs!')
        outputs, state = self.hypernet(hyper_inputs[0], state)
        if torch.isnan(outputs).any():
            print('nan in outputs!')
        h_hat_t = torch.stack([t[l - 1] for (t, l) in zip(outputs, title_len)])

        if isinstance(state, tuple):
            state = list(state)
            state[0] = state[0][-1]
            state[1] = state[1][-1]
        else:
            state = state[-1]
        infer_inputs_ = infer_inputs[0].transpose(0, 1).contiguous()
        infer_outputs = self.infernet(state, h_hat_t, infer_inputs_)

        if torch.isnan(infer_outputs).any():
            print('nan in infer outputs!')
        ## concatenate title and context
        sequence_output = infer_outputs

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # use the transformer outputs here; `outputs` is the hypernet tensor at this point
            output = (start_logits, end_logits) + infer_inputs[2:]
            return ((total_loss, ) +
                    output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=infer_inputs.hidden_states,
            attentions=infer_inputs.attentions,
        )
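The h_hat_t line in Example #12 gathers, for each sequence in the batch, the hidden state at its last valid (unpadded) timestep. A standalone sketch with dummy shapes:

import torch

batch, seq_len, hidden = 3, 6, 4
outputs = torch.randn(batch, seq_len, hidden)   # per-timestep RNN outputs
title_len = torch.tensor([4, 6, 2])             # true lengths before padding

h_hat_t = torch.stack([t[l - 1] for (t, l) in zip(outputs, title_len)])
print(h_hat_t.shape)                            # (3, 4): one vector per sequence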
Example #13
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        is_impossible=None,
        pq_end_pos=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = discriminator_hidden_states[0]

        query_sequence_output, _, query_attention_mask, _ = split_ques_context(
            sequence_output, pq_end_pos)

        sequence_output = self.attention(sequence_output,
                                         query_sequence_output,
                                         query_attention_mask)
        sequence_output = sequence_output + discriminator_hidden_states[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        first_word = sequence_output[:, 0, :]

        has_log = self.has_ans(first_word)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            if len(is_impossible.size()) > 1:
                is_impossible = is_impossible.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            choice_loss = loss_fct(has_log, is_impossible)
            total_loss = (start_loss + end_loss + choice_loss) / 3

        if not return_dict:
            output = (
                start_logits,
                end_logits,
            ) + discriminator_hidden_states[1:]
            return ((total_loss, ) +
                    output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
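Example #13 adds an answerability head: a classifier over the first-token representation whose loss is averaged with the span losses. A dummy-data sketch of just that piece (here has_ans is assumed to be a 2-way linear layer, which the snippet itself does not show):

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, hidden = 4, 10, 16
sequence_output = torch.randn(batch, seq_len, hidden)
has_ans = torch.nn.Linear(hidden, 2)           # assumed: answerable vs. impossible

first_word = sequence_output[:, 0, :]          # representation at the first ([CLS]) position
has_log = has_ans(first_word)                  # (batch, 2)
is_impossible = torch.tensor([0, 1, 0, 1])
choice_loss = CrossEntropyLoss()(has_log, is_impossible)
print(choice_loss)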