Example #1
    def forward(self, q_ids=None, char_ids=None, word_ids=None, token_type_ids=None, subject_ids=None,
                subject_labels=None,
                object_labels=None, eval_file=None,
                is_eval=False):

        # mask is True on real tokens (used by the CRFs); seq_mask is True on
        # padding positions (used as the transformer's key-padding mask).
        mask = char_ids != 0
        seq_mask = char_ids.eq(0)

        # Character embeddings plus word embeddings mapped through word_convert_char.
        char_emb = self.char_emb(char_ids)
        word_emb = self.word_convert_char(self.word_emb(word_ids))
        emb = char_emb + word_emb

        # Shared sentence encoder and emission scores for the subject CRF.
        sent_encoder = self.first_sentence_encoder(emb, seq_mask)
        ent_emission = self.ent_emission(sent_encoder)
        if not is_eval:
            # Training: condition the sentence encoding on the gold subject span
            # (start/end states fused through self.LayerNorm), then re-encode
            # with the transformer layer.
            sub_start_encoder = batch_gather(sent_encoder, subject_ids[:, 0])
            sub_end_encoder = batch_gather(sent_encoder, subject_ids[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(sent_encoder, subject)
            context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                       src_key_padding_mask=seq_mask).transpose(0, 1)

            # Joint loss: negative log-likelihood of the subject CRF plus that of
            # the object/predicate CRF.
            ent_loss = -self.ent_crf(ent_emission, subject_labels, mask=mask, reduction='mean')

            emission = self.emission(context_encoder)
            po_loss = -self.crf(emission, object_labels, mask=mask, reduction='mean')
            loss = ent_loss + po_loss

            return loss

        else:
            subject_preds = self.ent_crf.decode(emissions=ent_emission, mask=mask)
            answer_list = list()
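            # Decode subject spans from the CRF output: tag 1 marks the start of
            # a span and tag 2 (start tag + 1) marks its continuation.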
            for qid, sub_pred in zip(q_ids.cpu().numpy(), subject_preds):
                seq_len = min(len(eval_file[qid].context), self.max_len)
                tag_list = list()
                j = 0
                while j < seq_len:
                    end = j
                    flag = True

                    if sub_pred[j] == 1:
                        start = j
                        for k in range(start + 1, seq_len):
                            if sub_pred[k] != sub_pred[start] + 1:
                                end = k - 1
                                flag = False
                                break
                        if flag:
                            end = seq_len - 1
                        tag_list.append((start, end))
                    j = end + 1

                answer_list.append(tag_list)

            # Replicate the passage encoding once per predicted subject so each
            # (passage, subject) pair can be scored independently below.
            qid_ids, sent_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
            for i, subjects in enumerate(answer_list):
                if subjects:
                    qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                    pass_tensor = char_ids[i, :].unsqueeze(0).expand(len(subjects), char_ids.size(1))
                    new_sent_encoder = sent_encoder[i, :, :].unsqueeze(0).expand(len(subjects), sent_encoder.size(1),
                                                                                 sent_encoder.size(2))

                    token_type_id = torch.zeros((len(subjects), char_ids.size(1)), dtype=torch.long)
                    for index, (start, end) in enumerate(subjects):
                        token_type_id[index, start:end + 1] = 1

                    qid_ids.append(qid)
                    pass_ids.append(pass_tensor)
                    subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                    sent_encoders.append(new_sent_encoder)
                    token_type_ids.append(token_type_id)

            if len(qid_ids) == 0:
                # No subject was predicted for any example: return sentinel tensors.
                subject_ids = torch.zeros(1, 2).long().to(sent_encoder.device)
                qid_tensor = torch.tensor([-1], dtype=torch.long).to(sent_encoder.device)
                po_tensor = torch.zeros(1, sent_encoder.size(1)).long().to(sent_encoder.device)
                return qid_tensor, subject_ids, po_tensor

            qids = torch.cat(qid_ids).to(sent_encoder.device)
            pass_ids = torch.cat(pass_ids).to(sent_encoder.device)
            sent_encoders = torch.cat(sent_encoders).to(sent_encoder.device)
            subject_ids = torch.cat(subject_ids).to(sent_encoder.device)

            # Score the (passage, subject) pairs in chunks of split_heads to keep
            # memory bounded during inference.
            flag = False
            split_heads = 1024

            sent_encoders_ = torch.split(sent_encoders, split_heads, dim=0)
            pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
            subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)
            po_preds = list()
            for i in range(len(subject_encoder_)):
                sent_encoders = sent_encoders_[i]
                pass_ids = pass_ids_[i]
                subject_encoder = subject_encoder_[i]

                # A singleton chunk is duplicated here (the extra row is dropped
                # again after decoding), presumably to avoid batch-size-1 issues.
                if sent_encoders.size(0) == 1:
                    flag = True
                    sent_encoders = sent_encoders.expand(2, sent_encoders.size(1), sent_encoders.size(2))
                    subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
                    pass_ids = pass_ids.expand(2, pass_ids.size(1))
                sub_start_encoder = batch_gather(sent_encoders, subject_encoder[:, 0])
                sub_end_encoder = batch_gather(sent_encoders, subject_encoder[:, 1])
                subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
                context_encoder = self.LayerNorm(sent_encoders, subject)
                context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                           src_key_padding_mask=pass_ids.eq(0)).transpose(0, 1)
                emission = self.emission(context_encoder)
                po_pred = self.crf.decode(emissions=emission, mask=(pass_ids != 0))
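                # crf.decode returns variable-length tag lists; pad them to
                # max_len so they can be stacked into a single tensor.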
                max_len = pass_ids.size(1)
                temp_tag = copy.deepcopy(po_pred)
                for line in temp_tag:
                    line.extend([0] * (max_len - len(line)))
                # TODO:check
                po_pred = torch.tensor(temp_tag).to(emission.device)
                if flag:
                    po_pred = po_pred[1, :].unsqueeze(0)

                po_preds.append(po_pred)
            po_tensor = torch.cat(po_preds).to(qids.device)
            return qids, subject_ids, po_tensor
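
All three examples rely on a batch_gather helper to pull out the encoder state at the subject boundary positions. Its implementation is not shown on this page; the following is a minimal sketch of what it plausibly does, assuming it takes a (batch, seq_len, hidden) tensor and a (batch,) index tensor and returns the selected states as (batch, hidden).

import torch

def batch_gather(data, index):
    """Select data[b, index[b], :] for every element b of the batch."""
    # data:  (batch, seq_len, hidden); index: (batch,) long tensor of positions
    index = index.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, data.size(-1))
    return torch.gather(data, 1, index).squeeze(1)  # (batch, hidden)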

Example #2
    def forward(self,
                q_ids=None,
                passage_ids=None,
                segment_ids=None,
                token_type_ids=None,
                subject_ids=None,
                subject_labels=None,
                object_labels=None,
                eval_file=None,
                is_eval=False):
        mask = (passage_ids != 0).float()
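        # Average the hidden states of the last four BERT layers, then refine the
        # result with the lstm_encoder module (padding positions are masked out).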
        bert_encoder_ = self.bert(passage_ids,
                                  token_type_ids=segment_ids,
                                  attention_mask=mask,
                                  output_all_encoded_layers=True)[0][-4:]
        bert_encoder_ = torch.cat([m.unsqueeze(0)
                                   for m in bert_encoder_]).mean(0)
        bert_encoder = self.lstm_encoder(bert_encoder_, mask=passage_ids.eq(0))
        if not is_eval:
            sub_start_encoder = batch_gather(bert_encoder, subject_ids[:, 0])
            sub_end_encoder = batch_gather(bert_encoder, subject_ids[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(bert_encoder, subject)

            # Pointer-style heads: per-token start/end logits for subjects, and
            # per-predicate start/end logits for objects conditioned on the subject.
            sub_preds = self.subject_dense(bert_encoder)
            po_preds = self.po_dense(context_encoder).reshape(
                passage_ids.size(0), -1, self.classes_num, 2)

            subject_loss = self.loss_fct(sub_preds, subject_labels)
            subject_loss = subject_loss.mean(2)
            subject_loss = torch.sum(subject_loss * mask.float()) / torch.sum(
                mask.float())

            po_loss = self.loss_fct(po_preds, object_labels)
            po_loss = torch.sum(po_loss.mean(3), 2)
            po_loss = torch.sum(po_loss * mask.float()) / torch.sum(
                mask.float())

            loss = subject_loss + po_loss

            return loss

        else:

            subject_preds = nn.Sigmoid()(self.subject_dense(bert_encoder))
            answer_list = list()
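            # Decode subject spans: threshold the sigmoid start/end scores at 0.5
            # and pair each start with the nearest end at or after it, skipping
            # position 0 and anything beyond len(context) - 2.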
            for qid, sub_pred in zip(q_ids.cpu().numpy(),
                                     subject_preds.cpu().numpy()):
                context = eval_file[qid].bert_tokens
                start = np.where(sub_pred[:, 0] > 0.5)[0]
                end = np.where(sub_pred[:, 1] > 0.5)[0]
                subjects = []
                for i in start:
                    j = end[end >= i]
                    if i == 0 or i > len(context) - 2:
                        continue

                    if len(j) > 0:
                        j = j[0]
                        if j > len(context) - 2:
                            continue
                        subjects.append((i, j))

                answer_list.append(subjects)

            qid_ids, bert_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
            for i, subjects in enumerate(answer_list):
                if subjects:
                    qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                    pass_tensor = passage_ids[i, :].unsqueeze(0).expand(
                        len(subjects), passage_ids.size(1))
                    new_bert_encoder = bert_encoder[i, :, :].unsqueeze(
                        0).expand(len(subjects), bert_encoder.size(1),
                                  bert_encoder.size(2))

                    token_type_id = torch.zeros(
                        (len(subjects), passage_ids.size(1)), dtype=torch.long)
                    for index, (start, end) in enumerate(subjects):
                        token_type_id[index, start:end + 1] = 1

                    qid_ids.append(qid)
                    pass_ids.append(pass_tensor)
                    subject_ids.append(torch.tensor(subjects,
                                                    dtype=torch.long))
                    bert_encoders.append(new_bert_encoder)
                    token_type_ids.append(token_type_id)

            if len(qid_ids) == 0:
                subject_ids = torch.zeros(1, 2).long().to(bert_encoder.device)
                qid_tensor = torch.tensor([-1], dtype=torch.long).to(
                    bert_encoder.device)
                po_tensor = torch.zeros(1, bert_encoder.size(1)).long().to(
                    bert_encoder.device)
                return qid_tensor, subject_ids, po_tensor

            qids = torch.cat(qid_ids).to(bert_encoder.device)
            pass_ids = torch.cat(pass_ids).to(bert_encoder.device)
            bert_encoders = torch.cat(bert_encoders).to(bert_encoder.device)
            subject_ids = torch.cat(subject_ids).to(bert_encoder.device)

            flag = False
            split_heads = 1024

            bert_encoders_ = torch.split(bert_encoders, split_heads, dim=0)
            pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
            subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)

            po_preds = list()
            for i in range(len(bert_encoders_)):
                bert_encoders = bert_encoders_[i]
                pass_ids = pass_ids_[i]
                subject_encoder = subject_encoder_[i]

                if bert_encoders.size(0) == 1:
                    flag = True
                    bert_encoders = bert_encoders.expand(
                        2, bert_encoders.size(1), bert_encoders.size(2))
                    subject_encoder = subject_encoder.expand(
                        2, subject_encoder.size(1))
                sub_start_encoder = batch_gather(bert_encoders,
                                                 subject_encoder[:, 0])
                sub_end_encoder = batch_gather(bert_encoders,
                                               subject_encoder[:, 1])
                subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
                context_encoder = self.LayerNorm(bert_encoders, subject)

                po_pred = self.po_dense(context_encoder).reshape(
                    subject_encoder.size(0), -1, self.classes_num, 2)

                if flag:
                    po_pred = po_pred[1, :, :, :].unsqueeze(0)

                po_preds.append(po_pred)

            po_tensor = torch.cat(po_preds).to(qids.device)
            po_tensor = nn.Sigmoid()(po_tensor)
            return qids, subject_ids, po_tensor
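
In all three examples self.LayerNorm is called with two arguments, self.LayerNorm(encoder_states, subject), i.e. the normalization is conditioned on the subject representation. The module itself is not shown on this page; below is a minimal sketch of one common way such a conditional layer normalization is implemented (the class name, sizes, and zero initialization are assumptions, not this page's code).

import torch
import torch.nn as nn

class ConditionalLayerNorm(nn.Module):
    def __init__(self, hidden_size, cond_size, eps=1e-12):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        # Shifts to the gain/bias computed from the condition vector; zero-init
        # makes the module start out as a plain LayerNorm.
        self.weight_dense = nn.Linear(cond_size, hidden_size, bias=False)
        self.bias_dense = nn.Linear(cond_size, hidden_size, bias=False)
        nn.init.zeros_(self.weight_dense.weight)
        nn.init.zeros_(self.bias_dense.weight)

    def forward(self, x, cond):
        # x: (batch, seq_len, hidden); cond: (batch, cond_size) subject vector
        cond = cond.unsqueeze(1)                      # broadcast over seq_len
        weight = self.weight + self.weight_dense(cond)
        bias = self.bias + self.bias_dense(cond)
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True, unbiased=False)
        return weight * (x - mean) / (std + self.eps) + bias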

Example #3
    def forward(self, q_ids=None, char_ids=None, word_ids=None, token_type_ids=None, subject_ids=None,
                subject_labels=None,
                object_labels=None, eval_file=None,
                is_eval=False):

        # Same char + word embedding and sentence encoder as in Example #1, but
        # subjects and objects are scored with sigmoid pointer heads instead of CRFs.
        mask = char_ids != 0
        seq_mask = char_ids.eq(0)

        char_emb = self.char_emb(char_ids)
        word_emb = self.word_convert_char(self.word_emb(word_ids))
        emb = char_emb + word_emb
        sent_encoder = self.first_sentence_encoder(emb, seq_mask)

        if not is_eval:
            # Condition on the gold subject span as in Example #1.

            sub_start_encoder = batch_gather(sent_encoder, subject_ids[:, 0])
            sub_end_encoder = batch_gather(sent_encoder, subject_ids[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(sent_encoder, subject)
            context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                       src_key_padding_mask=seq_mask).transpose(0, 1)

            sub_preds = self.subject_dense(sent_encoder)
            po_preds = self.po_dense(context_encoder).reshape(char_ids.size(0), -1, self.classes_num, 2)

            subject_loss = self.loss_fct(sub_preds, subject_labels)
            subject_loss = subject_loss.mean(2)
            subject_loss = torch.sum(subject_loss * mask.float()) / torch.sum(mask.float())

            po_loss = self.loss_fct(po_preds, object_labels)
            po_loss = torch.sum(po_loss.mean(3), 2)
            po_loss = torch.sum(po_loss * mask.float()) / torch.sum(mask.float())

            loss = subject_loss + po_loss

            return loss

        else:

            subject_preds = nn.Sigmoid()(self.subject_dense(sent_encoder))
            answer_list = list()
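            # Decode subject spans: start positions with probability above 0.5
            # are paired with the nearest end position whose probability exceeds
            # 0.4; positions beyond the context length are skipped.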
            for qid, sub_pred in zip(q_ids.cpu().numpy(),
                                     subject_preds.cpu().numpy()):
                context = eval_file[qid].context
                start = np.where(sub_pred[:, 0] > 0.5)[0]
                end = np.where(sub_pred[:, 1] > 0.4)[0]
                subjects = []
                for i in start:
                    j = end[end >= i]
                    if i >= len(context):
                        continue
                    if len(j) > 0:
                        j = j[0]
                        if j >= len(context):
                            continue
                        subjects.append((i, j))

                answer_list.append(subjects)

            qid_ids, sent_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
            for i, subjects in enumerate(answer_list):
                if subjects:
                    qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                    pass_tensor = char_ids[i, :].unsqueeze(0).expand(len(subjects), char_ids.size(1))
                    new_sent_encoder = sent_encoder[i, :, :].unsqueeze(0).expand(len(subjects), sent_encoder.size(1),
                                                                                 sent_encoder.size(2))

                    token_type_id = torch.zeros((len(subjects), char_ids.size(1)), dtype=torch.long)
                    for index, (start, end) in enumerate(subjects):
                        token_type_id[index, start:end + 1] = 1

                    qid_ids.append(qid)
                    pass_ids.append(pass_tensor)
                    subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                    sent_encoders.append(new_sent_encoder)
                    token_type_ids.append(token_type_id)

            if len(qid_ids) == 0:
                # No subject was predicted for any example: return sentinel tensors.
                qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device)
                return qid_tensor, qid_tensor, qid_tensor

            qids = torch.cat(qid_ids).to(sent_encoder.device)
            pass_ids = torch.cat(pass_ids).to(sent_encoder.device)
            sent_encoders = torch.cat(sent_encoders).to(sent_encoder.device)
            subject_ids = torch.cat(subject_ids).to(sent_encoder.device)

            flag = False
            split_heads = 1024

            sent_encoders_ = torch.split(sent_encoders, split_heads, dim=0)
            pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
            subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)

            po_preds = list()
            for i in range(len(subject_encoder_)):
                sent_encoders = sent_encoders_[i]
                pass_ids = pass_ids_[i]
                subject_encoder = subject_encoder_[i]

                if sent_encoders.size(0) == 1:
                    flag = True
                    sent_encoders = sent_encoders.expand(2, sent_encoders.size(1), sent_encoders.size(2))
                    subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
                    pass_ids = pass_ids.expand(2, pass_ids.size(1))
                sub_start_encoder = batch_gather(sent_encoders, subject_encoder[:, 0])
                sub_end_encoder = batch_gather(sent_encoders, subject_encoder[:, 1])
                subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
                context_encoder = self.LayerNorm(sent_encoders, subject)
                context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                           src_key_padding_mask=pass_ids.eq(0)).transpose(0, 1)
                po_pred = self.po_dense(context_encoder).reshape(subject_encoder.size(0), -1, self.classes_num, 2)

                if flag:
                    po_pred = po_pred[1, :, :, :].unsqueeze(0)

                po_preds.append(po_pred)

            po_tensor = torch.cat(po_preds).to(qids.device)
            po_tensor = nn.Sigmoid()(po_tensor)
            return qids, subject_ids, po_tensor
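
The eval branches return (qids, subject_ids, po_tensor) rather than finished triples, so some post-processing is still needed. The following is a hypothetical sketch of how Example #3's outputs could be turned into (subject text, predicate id, object text) triples; decode_triples, the 0.5 threshold, and the eval_file[qid].context lookup are illustrative assumptions mirroring the decoding style used above, not code from this page.

import numpy as np

def decode_triples(qids, subject_ids, po_tensor, eval_file, threshold=0.5):
    triples = []
    for qid, sub, po in zip(qids.cpu().numpy(),
                            subject_ids.cpu().numpy(),
                            po_tensor.cpu().numpy()):
        if qid == -1:          # sentinel returned when nothing was predicted
            continue
        s_start, s_end = sub
        context = eval_file[qid].context
        # po has shape (seq_len, classes_num, 2): start/end scores per predicate.
        starts = np.argwhere(po[:, :, 0] > threshold)
        ends = np.argwhere(po[:, :, 1] > threshold)
        for o_start, p in starts:
            # Nearest object end for the same predicate class.
            candidates = [e for e, p2 in ends if p2 == p and e >= o_start]
            if not candidates:
                continue
            o_end = candidates[0]
            triples.append((context[s_start:s_end + 1], int(p),
                            context[o_start:o_end + 1]))
    return triples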