Code Example #1
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, answer_choice=None,
                sentence_span_list=None, sentence_ids=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        # mask convention: 1 marks a masked (padded) position, 0 marks a real token
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        # check_sentence_id_class_num(sentence_mask, sentence_ids)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, 1 - doc_mask, dim=1).unsqueeze(1).bmm(doc)
        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = layers.weighted_avg(doc, self.doc_sen_self_attn(doc, doc_mask)).view(batch, max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_hidden = masked_softmax(sentence_sim, 1 - sentence_mask).bmm(word_hidden).squeeze(1)

        if self.freeze_predictor:
            sentence_loss = self.get_batch_evidence_loss(sentence_sim, 1 - sentence_mask, sentence_ids)
            return {'loss': sentence_loss}

        yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1).squeeze_(1)
        output_dict = {'yesno_logits': yesno_logits,
                       'sentence_logits': sentence_scores,
                       'max_weight_index': sentence_scores.max(dim=1)[1],
                       'max_weight': sentence_scores.max(dim=1)[0]}
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
            loss += choice_loss
            if sentence_ids is not None:
                sentence_loss = self.get_batch_evidence_loss(sentence_sim, 1 - sentence_mask, sentence_ids)
                loss += self.evidence_lam * sentence_loss
        output_dict['loss'] = loss
        output_dict['sentence_sim'] = sentence_sim.detach().cpu().float()
        output_dict['sentence_mask'] = (1 - sentence_mask).detach().cpu().float()
        return output_dict
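
The snippets above and below rely on masked_softmax and masked_log_softmax without defining them. The following is a minimal sketch of their assumed behavior, not the repository's actual implementation: after the `1 - mask` inversion used at every call site, a 1 marks a valid position and a 0 a padded one, so padded positions are pushed to a large negative logit before the (log-)softmax.

import torch.nn.functional as F


def masked_softmax(vector, mask, dim=-1):
    # Softmax over `dim` with positions where mask == 0 suppressed.
    # The call sites pass `1 - original_mask`, so here 1 = valid, 0 = padded.
    while mask.dim() < vector.dim():
        mask = mask.unsqueeze(1)
    vector = vector.masked_fill(mask == 0, -1e4)  # large negative, still finite in fp16
    return F.softmax(vector, dim=dim)


def masked_log_softmax(vector, mask, dim=-1):
    # Log-softmax counterpart, paired with F.nll_loss for the evidence supervision.
    while mask.dim() < vector.dim():
        mask = mask.unsqueeze(1)
    vector = vector.masked_fill(mask == 0, -1e4)
    return F.log_softmax(vector, dim=dim)
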
Code Example #2
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, answer_choice=None,
                sentence_span_list=None, sentence_ids=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        # mask convention: 1 marks a masked (padded) position, 0 marks a real token
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view(batch, 1, -1)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)

        doc_vecs = layers.weighted_avg(doc, self.doc_sen_self_attn(doc, doc_mask)).view(batch, max_sen, -1)

        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_hidden = masked_softmax(sentence_sim, 1 - sentence_mask).bmm(doc_vecs).squeeze(1)

        yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1).squeeze_(1)
        output_dict = {'yesno_logits': yesno_logits,
                       'sentence_logits': sentence_scores,
                       'max_weight_index': sentence_scores.max(dim=1)[1],
                       'max_weight': sentence_scores.max(dim=1)[0]}
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
            loss += sentence_loss
        output_dict['loss'] = loss
        return output_dict
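
layers.weighted_avg, self.que_self_attn, and self.doc_sen_self_attn are also used without being shown. Below is a minimal sketch under the assumption that the self-attention modules return one normalized weight per token and weighted_avg collapses the sequence with those weights; the names and internals are hypothetical but consistent with the shapes in the snippets.

import torch.nn as nn
import torch.nn.functional as F


def weighted_avg(x, weights):
    # x: [batch, seq_len, h], weights: [batch, seq_len] summing to 1 -> [batch, h]
    return weights.unsqueeze(1).bmm(x).squeeze(1)


class SelfAttnPool(nn.Module):
    # Hypothetical stand-in for que_self_attn / doc_sen_self_attn: score each
    # token with a learned vector and normalize over the unpadded positions.
    def __init__(self, hidden_size):
        super().__init__()
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, x, mask):
        # x: [batch, seq_len, h]; mask: 1 for padded positions (as in the snippets)
        scores = self.scorer(x).squeeze(-1)            # [batch, seq_len]
        scores = scores.masked_fill(mask != 0, -1e4)   # suppress padding
        return F.softmax(scores, dim=-1)
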
Code Example #3
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None,
                sentence_label=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask convention: 1 marks a masked (padded) position, 0 marks a real token
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        # check_sentence_id_class_num(sentence_mask, sentence_ids)

        batch, max_sen, doc_len = doc_sen_mask.size()
        # que_len = que_mask.size(1)

        que_vec = layers.weighted_avg(que, self.que_self_attn(que,
                                                              que_mask)).view(
                                                                  batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, 1 - doc_mask,
                                     dim=1).unsqueeze(1).bmm(doc)
        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = layers.weighted_avg(doc,
                                       self.doc_sen_self_attn(doc,
                                                              doc_mask)).view(
                                                                  batch,
                                                                  max_sen, -1)

        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask)
        sentence_hidden = sentence_scores.bmm(word_hidden).squeeze(1)

        yesno_logits = self.yesno_predictor(
            torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = sentence_scores.squeeze(1)
        max_sentence_score = sentence_scores.max(dim=-1)
        output_dict = {
            'yesno_logits': yesno_logits,
            'sentence_logits': sentence_scores,
            'max_weight': max_sentence_score[0],
            'max_weight_index': max_sentence_score[1]
        }
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits,
                                          answer_choice,
                                          ignore_index=-1)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1),
                                                  1 - sentence_mask,
                                                  dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(
                log_sentence_sim, sentence_ids, ignore_index=-1)
            loss += sentence_loss
            if self.add_entropy:
                no_evidence_mask = (sentence_ids != -1)
                entropy = layers.get_masked_entropy(sentence_scores,
                                                    mask=no_evidence_mask)
                loss += self.evidence_lam * entropy
        if sentence_label is not None:
            # sentence_label: batch * List[k]
            # [batch, max_sen]
            # log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1)
            sentence_prob = 1 - sentence_scores
            log_sentence_sim = -torch.log(sentence_prob + 1e-15)
            negative_loss = 0
            for b in range(batch):
                for sen_id, k in enumerate(sentence_label[b]):
                    negative_loss += k * log_sentence_sim[b][sen_id]
            negative_loss /= batch
            loss += self.negative_lam * negative_loss
        output_dict['loss'] = loss
        return output_dict
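
The negative-supervision term at the end of Code Example #3 is computed with a Python double loop over sentence_label. Assuming the per-example label lists can be padded into a dense [batch, max_sen] tensor of counts (an assumption; the original keeps a list per example), the same quantity can be written in one vectorized expression:

import torch


def negative_evidence_loss(sentence_scores, sentence_label, negative_lam=1.0):
    # sentence_scores: [batch, max_sen] attention weights from the snippet above.
    # sentence_label:  [batch, max_sen] counts k, zero where no negative label applies.
    neg_log_prob = -torch.log(1.0 - sentence_scores + 1e-15)
    return negative_lam * (sentence_label * neg_log_prob).sum() / sentence_scores.size(0)
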
Code Example #4
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask convention: 1 marks a masked (padded) position, 0 marks a real token
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        batch, max_sen, doc_len = doc_sen_mask.size()
        # que_len = que_mask.size(1)

        que_vec = layers.weighted_avg(que, self.que_self_attn(que,
                                                              que_mask)).view(
                                                                  batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, 1 - doc_mask,
                                     dim=1).unsqueeze(1).bmm(doc)
        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = layers.weighted_avg(doc,
                                       self.doc_sen_self_attn(doc,
                                                              doc_mask)).view(
                                                                  batch,
                                                                  max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        # sentence_hidden = self.hard_sample(sentence_sim, use_gumbel=self.use_gumbel, dim=-1,
        #                                    hard=True, mask=(1 - sentence_mask)).bmm(word_hidden).squeeze(1)
        if self.training:
            _sample_prob, _sample_log_prob = self.sample_one_hot(
                sentence_sim, 1 - sentence_mask)
            loss_and_reward, _ = self.reward_func(word_hidden, que_vec,
                                                  answer_choice, _sample_prob,
                                                  _sample_log_prob)
            output_dict = {'loss': loss_and_reward}
        else:
            _prob, _ = self.sample_one_hot(sentence_sim, 1 - sentence_mask)
            loss, _yesno_logits = self.simple_step(word_hidden, que_vec,
                                                   answer_choice, _prob)
            sentence_scores = masked_softmax(sentence_sim,
                                             1 - sentence_mask,
                                             dim=-1).squeeze_(1)
            output_dict = {
                'max_weight': sentence_scores.max(dim=1)[0],
                'max_weight_index': sentence_scores.max(dim=1)[1],
                'sentence_logits': sentence_scores,
                'loss': loss,
                'yesno_logits': _yesno_logits
            }

        return output_dict
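
Code Example #4 delegates to self.sample_one_hot and self.reward_func, which are not shown. The sketch below is only a guess at the sampling step: draw one sentence per example from the masked distribution and return a one-hot selection together with the log-probability a REINFORCE-style reward_func would need.

import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def sample_one_hot(sentence_sim, mask):
    # sentence_sim: [batch, 1, max_sen]; mask: 1 for valid sentences (the call
    # site passes `1 - sentence_mask`). Hypothetical sketch, not the original code.
    logits = sentence_sim.squeeze(1).masked_fill(mask == 0, -1e4)
    probs = F.softmax(logits, dim=-1)                     # [batch, max_sen]
    dist = Categorical(probs=probs)
    idx = dist.sample()                                   # [batch]
    one_hot = torch.zeros_like(probs).scatter_(1, idx.unsqueeze(1), 1.0)
    return one_hot.unsqueeze(1), dist.log_prob(idx)
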
Code Example #5
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None,
                sentence_label=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask convention: 1 marks a masked (padded) position, 0 marks a real token
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask, cls_h = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list, return_cls=True)

        batch, max_sen, doc_len = doc_sen_mask.size()
        que_mask = 1 - que_mask
        doc_sen_mask = 1 - doc_sen_mask
        sentence_mask = 1 - sentence_mask

        cls_h.unsqueeze_(1)
        que_vec = masked_softmax(self.que_word_sum(cls_h, que),
                                 que_mask).bmm(que)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        doc_word_sim = self.doc_word_sum(cls_h,
                                         doc).view(batch * max_sen, doc_len)
        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        doc_sen_h = masked_softmax(doc_word_sim, doc_mask,
                                   dim=-1).unsqueeze(1).bmm(doc).view(
                                       batch, max_sen, -1)

        sentence_sim = self.doc_sen_sum(que_vec, doc_sen_h)
        sentence_scores = masked_softmax(sentence_sim, sentence_mask)
        doc_vec = sentence_scores.bmm(doc_sen_h).squeeze(1)

        yesno_logits = self.yesno_predictor(
            torch.cat([doc_vec, que_vec.squeeze(1)], dim=1))

        output_dict = {
            'yesno_logits': yesno_logits,
            'sentence_scores': sentence_scores
        }
        loss = 0
        if answer_choice is not None:
            loss += F.cross_entropy(yesno_logits,
                                    answer_choice,
                                    ignore_index=-1)
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1),
                                                  sentence_mask,
                                                  dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(
                log_sentence_sim, sentence_ids, ignore_index=-1)
            loss += sentence_loss
        if self.cls_sup and answer_choice is not None:
            extra_yesno_logits = self.extra_predictor(cls_h.squeeze(1))
            extra_choice_loss = self.extra_yesno_lam * F.cross_entropy(
                extra_yesno_logits, answer_choice, ignore_index=-1)
            loss += extra_choice_loss
        output_dict['loss'] = loss
        return output_dict
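
The pairwise scoring modules used throughout these snippets (word_similarity, vector_similarity, and in this variant que_word_sum, doc_word_sum, doc_sen_sum) all map a single query vector plus a sequence of vectors to a [batch, 1, seq_len] score tensor. One plausible minimal form, offered purely as an assumption about their interface:

import torch.nn as nn


class BilinearSimilarity(nn.Module):
    # Hypothetical stand-in: score every vector in y against the query vector x
    # with a learned bilinear form.
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x, y):
        # x: [batch, 1, h], y: [batch, seq_len, h] -> [batch, 1, seq_len]
        return self.linear(x).bmm(y.transpose(1, 2))
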
Code Example #6
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, answer_choice=None,
                sentence_span_list=None, sentence_ids=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        # mask convention: 1 marks a masked (padded) position, 0 marks a real token
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        # check_sentence_id_class_num(sentence_mask, sentence_ids)

        batch, max_sen, doc_len = doc_sen_mask.size()
        # que_len = que_mask.size(1)

        que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, 1 - doc_mask, dim=1).unsqueeze(1).bmm(doc)
        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = layers.weighted_avg(doc, self.doc_sen_self_attn(doc, doc_mask)).view(batch, max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        # Test performance of only evidence sentences
        # if not self.training:
        #     max_index = masked_softmax(sentence_sim, 1 - sentence_mask).max(dim=-1, keepdim=True)[1]
        #     one_hot_vec = torch.zeros_like(sentence_sim).scatter_(-1, max_index, 1.0)
        #     sentence_hidden = one_hot_vec.bmm(word_hidden).squeeze(1)
        # else:
        # Test performance of only max k evidence sentences
        # if not self.training:
        #     k_max_mask = rep_layers.get_k_max_mask(sentence_sim * (1 - sentence_mask.unsqueeze(1)).to(sentence_sim.dtype), dim=-1, k=2)
        #     sentence_hidden = masked_softmax(sentence_sim, k_max_mask).bmm(word_hidden).squeeze(1)
        # else:
        sentence_hidden = masked_softmax(sentence_sim, 1 - sentence_mask).bmm(word_hidden).squeeze(1)

        yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1).squeeze_(1)
        output_dict = {'yesno_logits': yesno_logits,
                       'sentence_logits': sentence_scores,
                       'max_weight_index': sentence_scores.max(dim=1)[1],
                       'max_weight': sentence_scores.max(dim=1)[0]}
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
            loss += sentence_loss
        output_dict['loss'] = loss
        return output_dict
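
The commented-out evaluation branches in Code Example #6 select either the single highest-weight sentence or the top-k sentences via rep_layers.get_k_max_mask. A minimal sketch of what such a helper might look like (an assumption about its interface, not the actual rep_layers code):

import torch


def get_k_max_mask(scores, dim=-1, k=2):
    # 1 for the k highest-scoring positions along `dim`, 0 elsewhere.
    _, idx = scores.topk(k, dim=dim)
    return torch.zeros_like(scores).scatter_(dim, idx, 1.0)
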
Code Example #7
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None,
                sentence_label=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask convention: 1 marks a masked (padded) position, 0 marks a real token
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        # check_sentence_id_class_num(sentence_mask, sentence_ids)

        batch, max_sen, doc_len = doc_sen_mask.size()
        # que_len = que_mask.size(1)

        que_vec = layers.weighted_avg(que, self.que_self_attn(que,
                                                              que_mask)).view(
                                                                  batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, 1 - doc_mask,
                                     dim=1).unsqueeze(1).bmm(doc)
        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = layers.weighted_avg(doc,
                                       self.doc_sen_self_attn(doc,
                                                              doc_mask)).view(
                                                                  batch,
                                                                  max_sen, -1)

        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        softmax_sim = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1)
        sentence_hidden = softmax_sim.bmm(word_hidden).squeeze(1)
        softmax_sim.squeeze_(dim=1)

        # yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))
        output_dict = {
            'sentence_logits': softmax_sim,
            'max_weight_index': softmax_sim.max(dim=1)[1],
            'max_weight': softmax_sim.max(dim=1)[0]
        }
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1),
                                                  1 - sentence_mask,
                                                  dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(
                log_sentence_sim, sentence_ids, ignore_index=-1)
            output_dict['loss'] = sentence_loss
        return output_dict