Exemple #1
0
class ERENet(nn.Module):
    """
    ERENet: joint entity and relation extraction.

    A shared sentence encoder feeds two CRF taggers. ``ent_crf`` tags the
    sentence with the ``ent_conf`` tag set (subjects); ``crf`` tags a
    representation of the sentence conditioned on one subject span at a time
    with the ``spo_conf`` tag set.
    """

    def __init__(self, args, word_emb, ent_conf, spo_conf):
        """Build the network.

        Args:
            args: configuration object; this constructor reads ``max_len``,
                ``char_vocab_size``, ``char_emb_size``, ``word_emb_size`` and
                ``hidden_size``. It is also passed to ``SentenceEncoder``.
            word_emb: pretrained word vectors, converted with
                ``torch.tensor(..., dtype=torch.float32)`` and kept frozen.
            ent_conf: subject tag set; only ``len(ent_conf)`` is used here.
            spo_conf: object/predicate tag set; only ``len(spo_conf)`` is used.
        """
        print('mhs using only char2v+w2v mixed  and word_emb is freeze ')
        super(ERENet, self).__init__()

        self.max_len = args.max_len

        self.word_emb = nn.Embedding.from_pretrained(torch.tensor(word_emb, dtype=torch.float32), freeze=True,
                                                     padding_idx=0)
        self.char_emb = nn.Embedding(num_embeddings=args.char_vocab_size, embedding_dim=args.char_emb_size,
                                     padding_idx=0)

        # Projects word vectors into the char-embedding space so both can be summed.
        self.word_convert_char = nn.Linear(args.word_emb_size, args.char_emb_size, bias=False)

        self.classes_num = len(spo_conf)

        self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
        # NOTE(review): if this is torch.nn.TransformerEncoderLayer, nhead=3
        # requires ``args.hidden_size * 2`` to be divisible by 3 -- confirm.
        self.encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, nhead=3)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
        self.LayerNorm = ConditionalLayerNorm(args.hidden_size * 2, eps=1e-12)

        self.ent_emission = nn.Linear(args.hidden_size * 2, len(ent_conf))
        self.ent_crf = CRF(len(ent_conf), batch_first=True)
        self.emission = nn.Linear(args.hidden_size * 2, len(spo_conf))
        self.crf = CRF(len(spo_conf), batch_first=True)
        # Not referenced by ``forward`` in this class; kept so the module's
        # attributes stay unchanged for external users.
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    @staticmethod
    def _extract_subject_spans(tags, seq_len):
        """Return inclusive ``(start, end)`` spans: a tag 1 followed by a run of tag 2.

        ``seq_len`` is clamped to ``len(tags)`` so a requested length longer
        than the decoded sequence cannot index past its end.
        """
        seq_len = min(seq_len, len(tags))
        spans = []
        pos = 0
        while pos < seq_len:
            end = pos
            if tags[pos] == 1:
                inside_tag = tags[pos] + 1
                end = pos + 1
                while end < seq_len and tags[end] == inside_tag:
                    end += 1
                end -= 1
                spans.append((pos, end))
            pos = end + 1
        return spans

    def _encode_with_subject(self, sent_encoder, subject_ids, padding_mask):
        """Condition the sentence encoding on subject start/end states, then run the transformer."""
        sub_start_encoder = batch_gather(sent_encoder, subject_ids[:, 0])
        sub_end_encoder = batch_gather(sent_encoder, subject_ids[:, 1])
        subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
        context_encoder = self.LayerNorm(sent_encoder, subject)
        # The transformer is fed the tensor with its first two axes swapped,
        # and the swap is undone on the way out.
        return self.transformer_encoder(context_encoder.transpose(1, 0),
                                        src_key_padding_mask=padding_mask).transpose(0, 1)

    def forward(self, q_ids=None, char_ids=None, word_ids=None, token_type_ids=None, subject_ids=None,
                subject_labels=None,
                object_labels=None, eval_file=None,
                is_eval=False):
        """Compute the training loss, or decode subjects and their object tags.

        Args:
            q_ids: per-example ids; in eval they index ``eval_file`` and are
                echoed back in the result.
            char_ids: char id tensor; id 0 is treated as padding.
            word_ids: word id tensor fed to the frozen word embedding.
            token_type_ids: accepted for interface compatibility; not used.
            subject_ids: training only; columns 0 and 1 are the subject start
                and end positions.
            subject_labels: training only; gold tags for ``ent_crf``.
            object_labels: training only; gold tags for ``crf``.
            eval_file: eval only; ``eval_file[qid].context`` gives the text
                whose length bounds span decoding.
            is_eval: selects the decoding path when true.

        Returns:
            Training: the sum of the two negative CRF log-likelihoods.
            Eval: ``(qids, subject_ids, po_tensor)`` with one row per decoded
            subject span. When no subject is decoded in the whole batch, a
            single placeholder row is returned with qid ``-1`` and all-zero
            ``subject_ids`` / ``po_tensor``.
        """
        mask = char_ids != 0
        seq_mask = char_ids.eq(0)

        char_emb = self.char_emb(char_ids)
        word_emb = self.word_convert_char(self.word_emb(word_ids))
        emb = char_emb + word_emb
        sent_encoder = self.first_sentence_encoder(emb, seq_mask)
        ent_emission = self.ent_emission(sent_encoder)

        if not is_eval:
            context_encoder = self._encode_with_subject(sent_encoder, subject_ids, seq_mask)

            ent_loss = -self.ent_crf(ent_emission, subject_labels, mask=mask, reduction='mean')

            emission = self.emission(context_encoder)
            po_loss = -self.crf(emission, object_labels, mask=mask, reduction='mean')
            return ent_loss + po_loss

        device = sent_encoder.device

        # Stage 1: decode subject tags and turn them into spans per example.
        subject_preds = self.ent_crf.decode(emissions=ent_emission, mask=mask)
        answer_list = []
        for qid, sub_pred in zip(q_ids.cpu().numpy(), subject_preds):
            seq_len = min(len(eval_file[qid].context), self.max_len)
            answer_list.append(self._extract_subject_spans(sub_pred, seq_len))

        # Stage 2: repeat each example once per decoded subject span.
        qid_rows, sent_rows, char_rows, span_rows = [], [], [], []
        for i, subjects in enumerate(answer_list):
            if not subjects:
                continue
            n_subjects = len(subjects)
            qid_rows.append(q_ids[i].unsqueeze(0).expand(n_subjects))
            char_rows.append(char_ids[i, :].unsqueeze(0).expand(n_subjects, char_ids.size(1)))
            sent_rows.append(sent_encoder[i, :, :].unsqueeze(0).expand(n_subjects, sent_encoder.size(1),
                                                                       sent_encoder.size(2)))
            span_rows.append(torch.tensor(subjects, dtype=torch.long))

        if not qid_rows:
            subject_ids = torch.zeros(1, 2).long().to(device)
            qid_tensor = torch.tensor([-1], dtype=torch.long).to(device)
            po_tensor = torch.zeros(1, sent_encoder.size(1)).long().to(device)
            return qid_tensor, subject_ids, po_tensor

        qids = torch.cat(qid_rows).to(device)
        pass_ids = torch.cat(char_rows).to(device)
        sent_encoders = torch.cat(sent_rows).to(device)
        subject_ids = torch.cat(span_rows).to(device)

        # Stage 3: tag objects for every (sentence, subject) row, in chunks of
        # at most ``split_heads`` rows.
        split_heads = 1024
        chunks = zip(torch.split(sent_encoders, split_heads, dim=0),
                     torch.split(pass_ids, split_heads, dim=0),
                     torch.split(subject_ids, split_heads, dim=0))
        po_preds = []
        for chunk_sent, chunk_chars, chunk_spans in chunks:
            # A single-row chunk is duplicated to two rows and the duplicate is
            # dropped after decoding. Presumably a downstream helper mishandles
            # batch size 1 -- TODO confirm.
            is_single = chunk_sent.size(0) == 1
            if is_single:
                chunk_sent = chunk_sent.expand(2, chunk_sent.size(1), chunk_sent.size(2))
                chunk_spans = chunk_spans.expand(2, chunk_spans.size(1))
                chunk_chars = chunk_chars.expand(2, chunk_chars.size(1))

            context_encoder = self._encode_with_subject(chunk_sent, chunk_spans, chunk_chars.eq(0))
            emission = self.emission(context_encoder)
            decoded = self.crf.decode(emissions=emission, mask=(chunk_chars != 0))

            # Right-pad each decoded tag list with 0 so the rows form a rectangular tensor.
            max_len = chunk_chars.size(1)
            padded = [list(tags) + [0] * (max_len - len(tags)) for tags in decoded]
            po_pred = torch.tensor(padded).to(emission.device)
            if is_single:
                po_pred = po_pred[1, :].unsqueeze(0)

            po_preds.append(po_pred)

        po_tensor = torch.cat(po_preds).to(qids.device)
        return qids, subject_ids, po_tensor
Exemple #2
0
class NERNet(nn.Module):
    """
        NERNet: character (and optional bi-character) embeddings -> BiLSTM -> CRF tagger.
    """
    def __init__(self, args, model_conf):
        """Build the tagger.

        ``model_conf`` supplies ``'char_emb'`` and ``'bichar_emb'`` (pretrained
        weight tensors or ``None``), ``'entity_type'`` (tag set, only its
        length is used) and, when no pretrained char weights are given,
        ``'char_vocab'``. ``args`` supplies ``char_emb_dim`` and ``hidden_size``.
        """
        super(NERNet, self).__init__()
        pretrained_char = model_conf['char_emb']
        pretrained_bichar = model_conf['bichar_emb']
        embed_size = args.char_emb_dim

        if pretrained_char is None:
            # No pretrained vectors: learn a fresh table sized by the vocabulary.
            self.char_emb = nn.Embedding(
                num_embeddings=len(model_conf['char_vocab']),
                embedding_dim=args.char_emb_dim,
                padding_idx=0)
        else:
            # Start from the pretrained vectors and keep them trainable.
            self.char_emb = nn.Embedding(
                num_embeddings=pretrained_char.shape[0],
                embedding_dim=pretrained_char.shape[1],
                padding_idx=0,
                _weight=pretrained_char)
            self.char_emb.weight.requires_grad = True
            embed_size = pretrained_char.size()[1]

        if pretrained_bichar is None:
            self.bichar_emb = None
        else:
            self.bichar_emb = nn.Embedding(
                num_embeddings=pretrained_bichar.shape[0],
                embedding_dim=pretrained_bichar.shape[1],
                padding_idx=0,
                _weight=pretrained_bichar)
            self.bichar_emb.weight.requires_grad = True
            # Bi-character features are concatenated onto the char features.
            embed_size += pretrained_bichar.size()[1]

        num_tags = len(model_conf['entity_type'])
        self.drop = nn.Dropout(p=0.5)
        self.sentence_encoder = nn.LSTM(embed_size,
                                        args.hidden_size,
                                        num_layers=1,
                                        batch_first=True,
                                        bidirectional=True)
        self.emission = nn.Linear(args.hidden_size * 2, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, char_id, bichar_id, label_id=None, is_eval=False):
        """Return the CRF loss (training) or a padded tensor of decoded tags (eval).

        Char id 0 is treated as padding. In eval mode each decoded tag
        sequence is right-padded with 0 to ``char_id.size(1)``.
        """
        # Padding positions; not consumed by the plain LSTM encoder below.
        mask = char_id.eq(0)

        features = self.char_emb(char_id)
        if self.bichar_emb is not None:
            features = torch.cat([features, self.bichar_emb(bichar_id)], dim=-1)
        features = self.drop(features)

        sen_encoded, _ = self.sentence_encoder(features)
        sen_encoded = self.drop(sen_encoded)

        bio_mask = char_id != 0
        # Emission scores are log-normalised over tags before the CRF.
        emission = F.log_softmax(self.emission(sen_encoded), dim=-1)

        if is_eval:
            decoded = self.crf.decode(emissions=emission, mask=bio_mask)
            width = char_id.size(1)
            padded = [list(tags) + [0] * (width - len(tags)) for tags in decoded]
            return torch.tensor(padded).to(emission.device)

        return -self.crf(emission, label_id, mask=bio_mask, reduction='mean')
Exemple #3
0
class ERENet(nn.Module):
    """
        ERENet: entity relation extraction.

        A CRF tags entities over the encoded sentence; relation logits are
        then scored for every pair of positions against each relation
        embedding.
    """
    def __init__(self, args, word_emb):
        """Build the network.

        Args:
            args: configuration object; this constructor reads ``activation``
                (``'relu'`` or ``'tanh'``, case-insensitive),
                ``word_emb_size``, ``char_emb_size``, ``char_vocab_size``,
                ``rel_emb_size``, ``ent_emb_size`` and ``hidden_size``. It is
                also passed to ``SentenceEncoder``.
            word_emb: pretrained word vectors, converted with
                ``torch.tensor(..., dtype=torch.float32)`` and kept frozen.

        Raises:
            ValueError: if ``args.activation`` is not a supported name.
        """
        super(ERENet, self).__init__()
        print('mhs with w2v')

        activation_name = args.activation.lower()
        if activation_name == 'relu':
            self.activation = nn.ReLU()
        elif activation_name == 'tanh':
            self.activation = nn.Tanh()
        else:
            # Fail at construction rather than with an AttributeError on the
            # first forward pass.
            raise ValueError(
                "unsupported activation {!r}; expected 'relu' or 'tanh'".format(args.activation))

        self.word_emb = nn.Embedding.from_pretrained(torch.tensor(
            word_emb, dtype=torch.float32),
                                                     freeze=True,
                                                     padding_idx=0)
        # Projects word vectors into the char-embedding space so both can be summed.
        self.word_convert_char = nn.Linear(args.word_emb_size,
                                           args.char_emb_size,
                                           bias=False)
        self.char_emb = nn.Embedding(num_embeddings=args.char_vocab_size,
                                     embedding_dim=args.char_emb_size,
                                     padding_idx=0)
        self.rel_emb = nn.Embedding(num_embeddings=len(BAIDU_RELATION),
                                    embedding_dim=args.rel_emb_size)
        self.ent_emb = nn.Embedding(num_embeddings=len(BAIDU_ENTITY),
                                    embedding_dim=args.ent_emb_size)

        self.sentence_encoder = SentenceEncoder(args, args.char_emb_size)

        self.emission = nn.Linear(args.hidden_size * 2, len(BAIDU_ENTITY))

        self.crf = CRF(len(BAIDU_ENTITY), batch_first=True)

        self.selection_u = nn.Linear(2 * args.hidden_size + args.ent_emb_size,
                                     args.rel_emb_size)
        self.selection_v = nn.Linear(2 * args.hidden_size + args.ent_emb_size,
                                     args.rel_emb_size)
        self.selection_uv = nn.Linear(2 * args.rel_emb_size, args.rel_emb_size)

    def forward(self,
                char_ids=None,
                word_ids=None,
                label_ids=None,
                spo_ids=None,
                is_eval=False):
        """Return the joint loss (training) or predictions (eval).

        Args:
            char_ids: char id tensor; id 0 is treated as padding.
            word_ids: word id tensor fed to the frozen word embedding.
            label_ids: training only; gold entity tags, used both for the CRF
                loss and as input to the entity-tag embedding.
            spo_ids: training only; gold relation targets passed to
                ``masked_BCEloss`` (laid out like the logits, per the
                ``'birj'`` einsum output).
            is_eval: selects the prediction path when true.

        Returns:
            Training: CRF negative log-likelihood plus the masked relation
            BCE loss. Eval: ``(ent_pre, selection_logits)``.
        """
        # Entity Extraction
        mask = char_ids.eq(0)
        char_emb = self.char_emb(char_ids)
        word_emb = self.word_convert_char(self.word_emb(word_ids))
        emb = char_emb + word_emb
        sent_encoder = self.sentence_encoder(emb, mask)

        bio_mask = char_ids != 0
        emission = self.emission(sent_encoder)

        # Relation Extraction
        if is_eval:
            # Decoding is only needed at eval time; training embeds the gold
            # tags instead, so the Viterbi pass is skipped there.
            ent_pre = self.entity_decoder(bio_mask,
                                          emission,
                                          max_len=sent_encoder.size(1))
            ent_encoder = self.ent_emb(ent_pre)
        else:
            ent_encoder = self.ent_emb(label_ids)

        rel_encoder = torch.cat((sent_encoder, ent_encoder), dim=2)

        B, L, _ = rel_encoder.size()
        # u[b, i, j] is the projection of position j; v[b, i, j] that of position i.
        u = self.activation(self.selection_u(rel_encoder)).unsqueeze(1).expand(
            B, L, L, -1)
        v = self.activation(self.selection_v(rel_encoder)).unsqueeze(2).expand(
            B, L, L, -1)
        uv = self.activation(self.selection_uv(torch.cat((u, v), dim=-1)))

        # batch x seq x rel x seq
        selection_logits = torch.einsum('bijh,rh->birj', uv,
                                        self.rel_emb.weight)

        if is_eval:
            return ent_pre, selection_logits

        crf_loss = -self.crf(
            emission, label_ids, mask=bio_mask, reduction='mean')
        selection_loss = self.masked_BCEloss(bio_mask, selection_logits,
                                             spo_ids)
        return crf_loss + selection_loss

    def masked_BCEloss(self, mask, selection_logits, selection_gold):
        """Sum the BCE-with-logits loss over non-padding position pairs, divided by ``mask.sum()``.

        ``mask`` marks non-padding positions; a pair contributes only when
        both of its positions are unmasked, for every relation.
        """
        # batch x seq x rel x seq
        selection_mask = (mask.unsqueeze(2) *
                          mask.unsqueeze(1)).unsqueeze(2).expand(
                              -1, -1, len(BAIDU_RELATION), -1)

        selection_loss = F.binary_cross_entropy_with_logits(selection_logits,
                                                            selection_gold,
                                                            reduction='none')
        selection_loss = selection_loss.masked_select(selection_mask).sum()
        # Normalised by the number of unmasked positions, not by the number of pairs.
        selection_loss /= mask.sum()
        return selection_loss

    def entity_decoder(self, bio_mask, emission, max_len):
        """Decode entity tags with the CRF and right-pad each sequence with 0 to ``max_len``.

        Returns a tensor on ``emission``'s device.
        """
        decoded_tag = self.crf.decode(emissions=emission, mask=bio_mask)
        padded = [list(tags) + [0] * (max_len - len(tags)) for tags in decoded_tag]
        return torch.tensor(padded).to(emission.device)