Example no. 1
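The methods below are class methods lifted out of their original module. As a minimal sketch (assuming the original file also defines the project-specific rep_layers module and the model classes), the imports they rely on would look roughly like this:

import logging

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

# rep_layers provides masked_softmax, masked_log_softmax, gumbel_softmax and
# split_doc_sen_que; it is project-specific and not reproduced here.
logger = logging.getLogger(__name__)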
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, answer_choice=None,
                sentence_span_list=None, sentence_ids=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        # check_sentence_id_class_num(sentence_mask, sentence_ids)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask, dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc, doc_mask).view(batch, max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_alpha = rep_layers.masked_softmax(sentence_sim, sentence_mask)
        sentence_hidden = sentence_alpha.bmm(word_hidden).squeeze(1)

        if self.freeze_predictor:
            sentence_loss = self.get_batch_evidence_loss(sentence_sim, sentence_mask, sentence_ids)
            return {'loss': sentence_loss}

        yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        output_dict = {'yesno_logits': torch.softmax(yesno_logits, dim=-1).detach().cpu().float()}

        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
            loss += choice_loss
            if sentence_ids is not None:
                sentence_loss = self.get_batch_evidence_loss(sentence_sim, sentence_mask, sentence_ids)
                # print('Sentence loss: ')
                # print(sentence_loss)
                loss += self.evidence_lam * sentence_loss
        output_dict['loss'] = loss
        output_dict['sentence_sim'] = sentence_sim.detach().cpu().float()
        output_dict['sentence_mask'] = sentence_mask.detach().cpu().float()
        return output_dict
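A hypothetical sketch of one plausible form of the evidence loss (the repository's get_batch_evidence_loss is not shown in this snippet), mirroring the explicit sentence loss computed in the multiple-choice variant further below: a masked log-softmax over the sentence similarities followed by an NLL loss against the annotated evidence sentence ids.

def evidence_loss_sketch(sentence_sim, sentence_mask, sentence_ids):
    # sentence_sim: [batch, 1, max_sen]; sentence_ids: [batch], with -1 for
    # examples that carry no evidence annotation (assumed convention).
    log_sim = rep_layers.masked_log_softmax(sentence_sim.squeeze(1),
                                            sentence_mask, dim=-1)
    return F.nll_loss(log_sim, sentence_ids.view(-1), ignore_index=-1)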
    def sample_one_hot(self, _similarity, _mask):
        # Sample evidence sentences from the masked softmax distribution over
        # sentence similarities: draw `self.sample_steps` samples (with their
        # log-probabilities) at training time, take the argmax at evaluation time.
        _probability = rep_layers.masked_softmax(_similarity, _mask)
        dtype = _probability.dtype
        _probability = _probability.float()
        # _log_probability = masked_log_softmax(_similarity, _mask)
        if self.training:
            _distribution = Categorical(_probability)
            _sample_index = _distribution.sample((self.sample_steps, ))
            logger.debug(str(_sample_index.size()))
            new_shape = (self.sample_steps, ) + _similarity.size()
            logger.debug(str(new_shape))
            _sample_one_hot = F.one_hot(_sample_index,
                                        num_classes=_similarity.size(-1))
            # _sample_one_hot = _similarity.new_zeros(new_shape).scatter(-1, _sample_index.unsqueeze(-1), 1.0)
            logger.debug(str(_sample_one_hot.size()))
            _log_prob = _distribution.log_prob(
                _sample_index)  # sample_steps, batch, 1
            assert _log_prob.size() == new_shape[:-1], (_log_prob.size(),
                                                        new_shape)
            _sample_one_hot = _sample_one_hot.transpose(
                0, 1)  # batch, sample_steps, 1, max_sen
            _log_prob = _log_prob.transpose(0, 1)  # batch, sample_steps, 1
            return _sample_one_hot.to(dtype=dtype), _log_prob.to(dtype=dtype)
        else:
            _max_index = _probability.float().max(dim=-1, keepdim=True)[1]
            _one_hot = torch.zeros_like(_similarity).scatter_(
                -1, _max_index, 1.0)
            # _log_prob = _log_probability.gather(-1, _max_index)
            return _one_hot, None
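A hypothetical sketch (not the repository's reward_func) of how the sampled one-hot sentence selections and log-probabilities returned by sample_one_hot could drive a REINFORCE-style objective; classifier, labels and the shapes in the comments are assumptions based on the surrounding code.

def reinforce_loss_sketch(sample_one_hot, log_prob, word_hidden, que_vec,
                          labels, classifier):
    # sample_one_hot: [batch, sample_steps, 1, max_sen]
    # log_prob:       [batch, sample_steps, 1]
    # word_hidden:    [batch, max_sen, h], que_vec: [batch, 1, h]
    batch, steps = log_prob.size(0), log_prob.size(1)
    # One selected sentence vector per sampled step.
    sent = sample_one_hot.matmul(word_hidden.unsqueeze(1)).squeeze(2)  # [batch, steps, h]
    que = que_vec.expand(-1, steps, -1)                                # [batch, steps, h]
    logits = classifier(torch.cat([sent, que], dim=-1))                # [batch, steps, n_cls]
    # Per-sample reward: negative cross-entropy of the end-task prediction.
    rewards = -F.cross_entropy(
        logits.reshape(batch * steps, -1),
        labels.unsqueeze(1).expand(-1, steps).reshape(-1),
        reduction='none').view(batch, steps)
    # REINFORCE: raise the log-probability of samples that earn a high reward.
    return -(rewards.detach() * log_prob.squeeze(-1)).mean()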
    def hard_sample(self, logits, use_gumbel, dim=-1, hard=True, mask=None):
        # Select a (near-)one-hot sentence distribution: Gumbel-softmax with a
        # straight-through estimate at training time, argmax at evaluation time.
        if use_gumbel:
            if self.training:
                probs = rep_layers.gumbel_softmax(logits,
                                                  mask=mask,
                                                  hard=hard,
                                                  dim=dim)
                return probs
            else:
                probs = rep_layers.masked_softmax(logits, mask, dim=dim)
                index = probs.max(dim, keepdim=True)[1]
                y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
                return y_hard
        else:
            # The original code fell through here and implicitly returned None;
            # fail loudly instead, since no non-Gumbel path is implemented.
            raise NotImplementedError('hard_sample currently requires use_gumbel=True')
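A minimal sketch of the masked, hard Gumbel-softmax sampling that hard_sample delegates to rep_layers.gumbel_softmax, built here on the standard torch.nn.functional API; the mask convention (1 = valid sentence) is an assumption.

def masked_gumbel_softmax_sketch(logits, mask=None, tau=1.0, hard=True, dim=-1):
    if mask is not None:
        # Suppress padded positions so they are (almost) never sampled.
        logits = logits.masked_fill(mask == 0, -1e4)
    # hard=True returns a one-hot sample in the forward pass while keeping the
    # soft Gumbel-softmax gradient in the backward pass (straight-through).
    return F.gumbel_softmax(logits, tau=tau, hard=hard, dim=dim)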
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                sentence_span_list=None,
                sentence_ids=None,
                max_sentences: int = 0):
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(
            -1,
            token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(
            -1,
            attention_mask.size(-1)) if attention_mask is not None else None
        sequence_output, _ = self.bert(flat_input_ids,
                                       flat_token_type_ids,
                                       flat_attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, flat_token_type_ids, flat_attention_mask, sentence_span_list,
                                         max_sentences=max_sentences)

        batch, max_sen, doc_len = doc_sen_mask.size()
        # que_len = que_mask.size(1)

        # que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view(batch, 1, -1)
        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)
        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask,
                                                dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc,
                                          doc_mask).view(batch, max_sen, -1)

        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_hidden = self.hard_sample(
            sentence_sim,
            use_gumbel=self.use_gumbel,
            dim=-1,
            hard=True,
            mask=sentence_mask).bmm(word_hidden).squeeze(1)

        choice_logits = self.classifier(
            torch.cat([sentence_hidden, que_vec.squeeze(1)],
                      dim=1)).reshape(-1, self.num_choices)

        sentence_scores = rep_layers.masked_softmax(sentence_sim,
                                                    sentence_mask,
                                                    dim=-1).squeeze_(1)
        output_dict = {
            'choice_logits':
            choice_logits.float(),
            'sentence_logits':
            sentence_scores.reshape(choice_logits.size(0), self.num_choices,
                                    max_sen).detach().cpu().float(),
        }
        loss = 0
        if labels is not None:
            choice_loss = F.cross_entropy(choice_logits, labels)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = rep_layers.masked_log_softmax(
                sentence_sim.squeeze(1), sentence_mask, dim=-1)
            sentence_loss = F.nll_loss(log_sentence_sim,
                                       sentence_ids.view(batch),
                                       reduction='sum',
                                       ignore_index=-1)
            loss += self.evidence_lam * sentence_loss / choice_logits.size(0)
        output_dict['loss'] = loss
        return output_dict
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                sentence_span_list=None,
                sentence_ids=None,
                max_sentences: int = 0):
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(
            -1,
            token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(
            -1,
            attention_mask.size(-1)) if attention_mask is not None else None
        sequence_output, _ = self.bert(flat_input_ids,
                                       flat_token_type_ids,
                                       flat_attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, flat_token_type_ids, flat_attention_mask, sentence_span_list,
                                         max_sentences=max_sentences)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)
        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask,
                                                dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc,
                                          doc_mask).view(batch, max_sen, -1)

        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        if self.training:
            _sample_prob, _sample_log_prob = self.sample_one_hot(
                sentence_sim, sentence_mask)
            loss_and_reward, _ = self.reward_func(word_hidden, que_vec, labels,
                                                  _sample_prob,
                                                  _sample_log_prob)
            output_dict = {'loss': loss_and_reward}
        else:
            _prob, _ = self.sample_one_hot(sentence_sim, sentence_mask)
            loss, _choice_logits = self.simple_step(word_hidden, que_vec,
                                                    labels, _prob)
            sentence_scores = rep_layers.masked_softmax(sentence_sim,
                                                        sentence_mask,
                                                        dim=-1).squeeze_(1)
            output_dict = {
                'sentence_logits': sentence_scores.float(),
                'loss': loss,
                'choice_logits': _choice_logits.float()
            }

        return output_dict
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask,
                                                dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc,
                                          doc_mask).view(batch, max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        # sentence_hidden = self.hard_sample(sentence_sim, use_gumbel=self.use_gumbel, dim=-1,
        #                                    hard=True, mask=(1 - sentence_mask)).bmm(word_hidden).squeeze(1)
        if self.training:
            _sample_prob, _sample_log_prob = self.sample_one_hot(
                sentence_sim, sentence_mask)
            loss_and_reward, _ = self.reward_func(word_hidden, que_vec,
                                                  answer_choice, _sample_prob,
                                                  _sample_log_prob)
            output_dict = {'loss': loss_and_reward}
        else:
            _prob, _ = self.sample_one_hot(sentence_sim, sentence_mask)
            loss, _yesno_logits = self.simple_step(word_hidden, que_vec,
                                                   answer_choice, _prob)
            sentence_scores = rep_layers.masked_softmax(sentence_sim,
                                                        sentence_mask,
                                                        dim=-1).squeeze_(1)
            output_dict = {
                'sentence_logits':
                sentence_scores.float().detach().cpu().tolist(),
                'loss':
                loss,
                'yesno_logits':
                torch.softmax(_yesno_logits, dim=-1).float().detach().cpu(),
                'sentence_sim':
                sentence_sim.float().detach().cpu(),
                'sentence_mask':
                sentence_mask.detach().cpu().float()
            }

        return output_dict
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask,
                                                dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc,
                                          doc_mask).view(batch, max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_hidden = self.hard_sample(
            sentence_sim,
            use_gumbel=self.use_gumbel,
            dim=-1,
            hard=True,
            mask=sentence_mask).bmm(word_hidden).squeeze(1)

        yesno_logits = self.yesno_predictor(
            torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = rep_layers.masked_softmax(sentence_sim,
                                                    sentence_mask,
                                                    dim=-1).squeeze_(1)
        output_dict = {
            'yesno_logits':
            torch.softmax(yesno_logits, dim=-1).detach().cpu().float(),
            'sentence_logits':
            sentence_scores
        }
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits,
                                          answer_choice,
                                          ignore_index=-1)
            loss += choice_loss
        # if sentence_ids is not None:
        #     log_sentence_sim = rep_layers.masked_log_softmax(sentence_sim.squeeze(1), sentence_mask, dim=-1)
        #     sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
        #     loss += sentence_loss
        output_dict['loss'] = loss
        return output_dict