Example #1
    def test_case2(self):
        # Test that the CRF works correctly.
        import json
        import torch
        from fastNLP import seq_len_to_mask

        with open('tests/data_for_tests/modules/decoder/crf.json', 'r') as f:
            data = json.load(f)

        bio_logits = torch.FloatTensor(data['bio_logits'])
        bio_scores = data['bio_scores']
        bio_path = data['bio_path']
        bio_trans_m = torch.FloatTensor(data['bio_trans_m'])
        bio_seq_lens = torch.LongTensor(data['bio_seq_lens'])

        bmes_logits = torch.FloatTensor(data['bmes_logits'])
        bmes_scores = data['bmes_scores']
        bmes_path = data['bmes_path']
        bmes_trans_m = torch.FloatTensor(data['bmes_trans_m'])
        bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens'])

        labels = ['O']
        for label in ['X', 'Y']:
            for tag in 'BI':
                labels.append('{}-{}'.format(tag, label))
        id2label = {idx: label for idx, label in enumerate(labels)}
        num_tags = len(id2label)

        mask = seq_len_to_mask(bio_seq_lens)

        from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
        fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
                                                                                                     include_start_end=True))
        fast_CRF.trans_m.data = bio_trans_m
        fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
        # score equal
        self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()])
        # seq equal
        self.assertListEqual(bio_path, fast_res[0])

        labels = []
        for label in ['X', 'Y']:
            for tag in 'BMES':
                labels.append('{}-{}'.format(tag, label))
        id2label = {idx: label for idx, label in enumerate(labels)}
        num_tags = len(id2label)

        mask = seq_len_to_mask(bmes_seq_lens)

        fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
                                                                                                     encoding_type='BMES',
                                                                                                     include_start_end=True))
        fast_CRF.trans_m.data = bmes_trans_m
        fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
        # score equal
        self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()])
        # seq equal
        self.assertListEqual(bmes_path, fast_res[0])
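Note: every example below relies on fastNLP's seq_len_to_mask to turn a vector of sentence lengths into a padding mask (true for real tokens, false for padding). A minimal sketch of the equivalent logic in plain PyTorch, using a hypothetical helper name and assuming a boolean result:

import torch

def seq_len_to_mask_sketch(seq_len, max_len=None):
    # True marks real tokens, False marks padding
    max_len = int(max_len if max_len is not None else seq_len.max())
    positions = torch.arange(max_len).expand(len(seq_len), max_len)
    return positions < seq_len.unsqueeze(1)

seq_len_to_mask_sketch(torch.LongTensor([3, 2]))
# tensor([[ True,  True,  True],
#         [ True,  True, False]])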
Example #2
    def evaluate(self, pred, target, seq_len=None):
        """
        evaluate函数将针对一个批次的预测结果做评价指标的累计

        :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
                torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
        :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
                torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len])
        :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]).
                如果mask也被传进来的话seq_len会被忽略.

        """
        # TODO: improve this error message — the user does not know what `pred` is, so report the offending value.
        if not isinstance(pred, torch.Tensor):
            raise TypeError(
                f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(
                f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                f"got {type(target)}.")

        if seq_len is not None and not isinstance(seq_len, torch.Tensor):
            raise TypeError(
                f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                f"got {type(seq_len)}.")

        if seq_len is not None and target.dim() > 1:
            max_len = target.size(1)
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = torch.ones_like(target).long().to(target.device)
        masks = masks.eq(False)

        if pred.dim() == target.dim():
            pass
        elif pred.dim() == target.dim() + 1:
            pred = pred.argmax(dim=-1)
            if seq_len is None and target.dim() > 1:
                logger.warning(
                    "You are not passing `seq_len` to exclude pad when calculate accuracy."
                )
        else:
            raise RuntimeError(
                f"In {_get_func_signature(self.evaluate)}, when pred have "
                f"size:{pred.size()}, target should have size: {pred.size()} or "
                f"{pred.size()[:-1]}, got {target.size()}.")

        target_idxes = set(target.reshape(-1).tolist())
        target = target.to(pred)
        for target_idx in target_idxes:
            self._tp[target_idx] += torch.sum(
                (pred == target_idx).long().masked_fill(
                    target != target_idx, 0).masked_fill(masks, 0)).item()
            self._fp[target_idx] += torch.sum(
                (pred != target_idx).long().masked_fill(
                    target != target_idx, 0).masked_fill(masks, 0)).item()
            self._fn[target_idx] += torch.sum(
                (pred == target_idx).long().masked_fill(
                    target == target_idx, 0).masked_fill(masks, 0)).item()
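The tp/fp/fn accumulation above excludes padding by flipping the valid-token mask (masks.eq(False) is True at pad positions) and zeroing those positions with masked_fill. A small illustrative sketch of that pattern, with made-up tensors:

import torch

pred = torch.tensor([[1, 2, 0], [2, 2, 0]])
target = torch.tensor([[1, 0, 0], [2, 1, 0]])
valid = torch.tensor([[1, 1, 0], [1, 1, 0]]).bool()  # e.g. from seq_len_to_mask
pad = valid.eq(False)                                # True at padding positions

# count predictions of class 2 that hit class-2 targets, ignoring padding
tp_2 = torch.sum((pred == 2).long()
                 .masked_fill(target != 2, 0)
                 .masked_fill(pad, 0)).item()
print(tp_2)  # 1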
Example #3
    def forward(self, chars, bigrams, seq_len, target):
        embed_char = self.char_embed(chars)

        if self.use_bigram:

            embed_bigram = self.bigram_embed(bigrams)

            embedding = torch.cat([embed_char, embed_bigram], dim=-1)
        else:

            embedding = embed_char

        embedding = self.embed_dropout(embedding)

        encoded_h, encoded_c = self.encoder(embedding, seq_len)

        encoded_h = self.output_dropout(encoded_h)

        pred = self.output(encoded_h)

        mask = seq_len_to_mask(seq_len)

        # pred = self.crf(pred)

        # batch_size, sent_len = pred.shape[0], pred.shape[1]
        # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            return {'pred': pred}
Example #4
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab,
                                model_dir_or_name=None,
                                embedding_dim=10)

        encoder_output = torch.randn(2, 3, 10)
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])
        encoder_mask = seq_len_to_mask(src_seq_len)

        for flag in [True, False]:
            for attention in [True, False]:
                with self.subTest(bind_decoder_input_output_embed=flag,
                                  attention=attention):
                    decoder = LSTMSeq2SeqDecoder(
                        embed=embed,
                        num_layers=2,
                        hidden_size=10,
                        dropout=0.3,
                        bind_decoder_input_output_embed=flag,
                        attention=attention)
                    state = decoder.init_state(encoder_output, encoder_mask)
                    output = decoder(tgt_words_idx, state)
                    self.assertEqual(tuple(output.size()), (2, 4, len(vocab)))
Example #5
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)

        encoder_output = torch.randn(2, 3, 10)
        src_seq_len = torch.LongTensor([3, 2])
        encoder_mask = seq_len_to_mask(src_seq_len)

        for flag in [True, False]:
            with self.subTest(bind_decoder_input_output_embed=flag):
                decoder = TransformerSeq2SeqDecoder(
                    embed=embed,
                    pos_embed=None,
                    d_model=10,
                    num_layers=2,
                    n_head=5,
                    dim_ff=20,
                    dropout=0.1,
                    bind_decoder_input_output_embed=flag)
                state = decoder.init_state(encoder_output, encoder_mask)
                output = decoder(tokens=torch.randint(0,
                                                      len(vocab),
                                                      size=(2, 4)),
                                 state=state)
                self.assertEqual(output.size(), (2, 4, len(vocab)))
Example #6
    def _forward(self, words, seq_len, target=None):
        words = self.embedding(words)
        outputs, _ = self.lstm(words, seq_len)
        outputs = self.dropout(outputs)

        logits = F.log_softmax(self.fc(outputs), dim=-1)

        if target is not None:
            loss = self.crf(logits, target,
                            seq_len_to_mask(seq_len,
                                            max_len=logits.size(1))).mean()
            return {Const.LOSS: loss}
        else:
            pred, _ = self.crf.viterbi_decode(
                logits, seq_len_to_mask(seq_len, max_len=logits.size(1)))
            return {Const.OUTPUT: pred}
    def forward(self,
                lattice,
                bigrams,
                seq_len,
                lex_num,
                pos_s,
                pos_e,
                target,
                chars_target=None):
        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        words = lattice[:, :max_seq_len]
        mask = seq_len_to_mask(seq_len).bool()
        words.masked_fill_((~mask), self.vocabs['lattice'].padding_idx)
        encoded = self.bert_embedding(words)

        if self.after_bert == 'lstm':
            encoded, _ = self.lstm(encoded, seq_len)
            encoded = self.dropout(encoded)

        pred = self.output(encoded)

        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}

            return result
Example #8
    def test_case3(self):
        # Test that the CRF loss never becomes negative.
        import torch
        from fastNLP.modules.decoder.crf import ConditionalRandomField
        from fastNLP.core.utils import seq_len_to_mask
        from torch import optim
        from torch import nn

        num_tags, include_start_end_trans = 4, True
        num_samples = 4
        lengths = torch.randint(3, 50, size=(num_samples, )).long()
        max_len = lengths.max()
        tags = torch.randint(num_tags, size=(num_samples, max_len))
        masks = seq_len_to_mask(lengths)
        feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
        crf = ConditionalRandomField(num_tags, include_start_end_trans)
        optimizer = optim.SGD(
            [param
             for param in crf.parameters() if param.requires_grad] + [feats],
            lr=0.1)
        for _ in range(10):
            loss = crf(feats, tags, masks).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if _ % 1000 == 0:
                print(loss)
            self.assertGreater(loss.item(), 0,
                               "CRF loss cannot be less than 0.")
    def forward(self, chars, bigrams, seq_len, target, chars_target=None):
        # print('**self.training: {} **'.format(self.training))
        batch_size = chars.size(0)
        max_seq_len = chars.size(1)
        chars_embed = self.char_embed(chars)
        if self.use_bigram:
            bigrams_embed = self.bigram_embed(bigrams)
            embedding = torch.cat([chars_embed, bigrams_embed], dim=-1)
        else:
            embedding = chars_embed
        if self.embed_dropout_pos == '0':
            embedding = self.embed_dropout(embedding)

        embedding = self.w_proj(embedding)
        if self.embed_dropout_pos == '1':
            embedding = self.embed_dropout(embedding)

        if self.use_abs_pos:
            embedding = self.pos_encode(embedding)

        if self.embed_dropout_pos == '2':
            embedding = self.embed_dropout(embedding)

        encoded = self.encoder(embedding, seq_len)

        if hasattr(self, 'output_dropout'):
            encoded = self.output_dropout(encoded)

        pred = self.output(encoded)

        mask = seq_len_to_mask(seq_len).bool()

        if self.mode['debug']:
            print('debug mode:finish!')
            exit(1208)
        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
            if self.self_supervised:
                # print('self supervised loss added!')
                chars_pred = self.output_self_supervised(encoded)
                chars_pred = chars_pred.view(
                    size=[batch_size * max_seq_len, -1])
                chars_target = chars_target.view(
                    size=[batch_size * max_seq_len])
                self_supervised_loss = self.loss_func(chars_pred, chars_target)
                # print('self_supervised_loss:{}'.format(self_supervised_loss))
                # print('supervised_loss:{}'.format(loss))
                loss += self_supervised_loss
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}
            if self.self_supervised:
                chars_pred = self.output_self_supervised(encoded)
                result['chars_pred'] = chars_pred

            return result
Example #10
    def forward(self,
                chars,
                bigrams,
                seq_len,
                target,
                skips_l2r_source,
                skips_l2r_word,
                lexicon_count,
                skips_r2l_source=None,
                skips_r2l_word=None,
                lexicon_count_back=None):
        # print('skips_l2r_word_id:{}'.format(skips_l2r_word.size()))
        batch = chars.size(0)
        max_seq_len = chars.size(1)
        # max_lexicon_count = skips_l2r_word.size(2)

        embed_char = self.char_embed(chars)
        if self.use_bigram:

            embed_bigram = self.bigram_embed(bigrams)

            embedding = torch.cat([embed_char, embed_bigram], dim=-1)
        else:

            embedding = embed_char

        embed_nonword = self.embed_dropout(embedding)

        # skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1])
        embed_word = self.word_embed(skips_l2r_word)
        embed_word = self.embed_dropout(embed_word)
        # embed_word = torch.reshape(embed_word,shape=[batch,max_seq_len,max_lexicon_count,-1])

        encoded_h, encoded_c = self.encoder(embed_nonword, seq_len,
                                            skips_l2r_source, embed_word,
                                            lexicon_count)

        if self.bidirectional:
            embed_word_back = self.word_embed(skips_r2l_word)
            embed_word_back = self.embed_dropout(embed_word_back)
            encoded_h_back, encoded_c_back = self.encoder_back(
                embed_nonword, seq_len, skips_r2l_source, embed_word_back,
                lexicon_count_back)
            encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1)

        encoded_h = self.output_dropout(encoded_h)

        pred = self.output(encoded_h)

        mask = seq_len_to_mask(seq_len)

        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            return {'pred': pred}
Example #11
    def forward(self, feats, seq_lens, gold_heads=None):
        """
        max_len是包含root的
        :param chars: batch_size x max_len
        :param ngrams: batch_size x max_len*ngram_per_char
        :param seq_lens: batch_size
        :param gold_heads: batch_size x max_len
        :param pre_chars: batch_size x max_len
        :param pre_ngrams: batch_size x max_len*ngram_per_char
        :return dict: parsing results
            arc_pred: [batch_size, seq_len, seq_len]
            label_pred: [batch_size, seq_len, seq_len]
            mask: [batch_size, seq_len]
            head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
        """
        # prepare embeddings
        batch_size,seq_len,_ = feats.shape
        # print('forward {} {}'.format(batch_size, seq_len))

        # get sequence mask
        mask = seq_len_to_mask(seq_lens).long()

        # for arc biaffine
        # mlp, reduce dim
        feat = self.mlp(feats)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
        # use gold or predicted arc to predict label
        if gold_heads is None or not self.training:
            # use greedy decoding in training
            if self.training or self.use_greedy_infer:
                heads = self.greedy_decoder(arc_pred, mask)
            else:
                heads = self.mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training # must be training mode
            if gold_heads is None:
                heads = self.greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = gold_heads
        # heads: batch_size x max_len

        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=feats.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label]

        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
        if head_pred is not None:
            res_dict['head_pred'] = head_pred
        return res_dict
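The batch_range indexing above selects, for every token, the feature vector of its (predicted or gold) head. A small sketch of that advanced-indexing step with made-up sizes:

import torch

batch_size, max_len, hidden = 2, 3, 4
label_head = torch.randn(batch_size, max_len, hidden)
heads = torch.tensor([[0, 0, 1], [0, 2, 0]])         # head index for each token
batch_range = torch.arange(batch_size).unsqueeze(1)  # (batch_size, 1), broadcasts over tokens
gathered = label_head[batch_range, heads]            # (batch_size, max_len, hidden)
assert gathered[0, 2].equal(label_head[0, 1])        # token 2 in sentence 0 points to head 1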
Example #12
    def get_batch_generation(self, samples_list, try_cuda=True):
        if not samples_list:
            return None
        if try_cuda:
            self.try_cuda()

        tensor_list = []
        masked_indices_list = []
        max_len = 0
        output_tokens_list = []
        seq_len = []
        for sample in samples_list:
            masked_inputs_list = sample["masked_sentences"]

            tokens_list = [self.tokenizer.bos_token_id]

            for idx, masked_input in enumerate(masked_inputs_list):
                tokens_list.extend(
                    self.tokenizer.encode(" " + masked_input.strip(),
                                          add_special_tokens=False))
                tokens_list.append(self.tokenizer.eos_token_id)

            # tokens = torch.cat(tokens_list)[: self.max_sentence_length]
            tokens = torch.tensor(tokens_list)[:self.max_sentence_length]
            output_tokens_list.append(tokens.long().cpu().numpy())

            seq_len.append(len(tokens))
            if len(tokens) > max_len:
                max_len = len(tokens)
            tensor_list.append(tokens)
            masked_index = (
                tokens == self.tokenizer.mask_token_id).nonzero().numpy()
            for x in masked_index:
                masked_indices_list.append([x[0]])
        tokens_list = []
        for tokens in tensor_list:
            pad_length = max_len - len(tokens)
            if pad_length > 0:
                pad_tensor = torch.full([pad_length],
                                        self.tokenizer.pad_token_id,
                                        dtype=torch.int)
                tokens = torch.cat((tokens, pad_tensor.long()))
            tokens_list.append(tokens)

        batch_tokens = torch.stack(tokens_list)
        seq_len = torch.LongTensor(seq_len)
        attn_mask = seq_len_to_mask(seq_len)

        with torch.no_grad():
            # with utils.eval(self.model.model):
            self.model.eval()
            outputs = self.model(
                batch_tokens.long().to(device=self._model_device),
                attention_mask=attn_mask.to(device=self._model_device))
            log_probs = outputs[0]

        return log_probs.cpu(), output_tokens_list, masked_indices_list
def prepare_env():
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    encoder_output = torch.randn(2, 3, 10)
    src_seq_len = torch.LongTensor([3, 2])
    encoder_mask = seq_len_to_mask(src_seq_len)

    return embed, encoder_output, encoder_mask
Example #14
def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    mask = seq_len_to_mask(tgt_seq_len).eq(0)
    target = tgt_words_idx.masked_fill(mask, -100)

    for i in range(100):
        optimizer.zero_grad()
        pred = model(src_words_idx, tgt_words_idx,
                     src_seq_len)['pred']  # bsz x max_len x vocab_size
        loss = F.cross_entropy(pred.transpose(1, 2), target)
        loss.backward()
        optimizer.step()

    right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
    return right_count
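The masked_fill(mask, -100) above works because F.cross_entropy skips targets equal to its ignore_index, which defaults to -100, so padded positions contribute nothing to the loss. A minimal sketch with illustrative shapes:

import torch
import torch.nn.functional as F

pred = torch.randn(2, 4, 7)                        # bsz x max_len x vocab_size
tgt = torch.tensor([[1, 2, 3, -100], [2, 3, -100, -100]])
loss = F.cross_entropy(pred.transpose(1, 2), tgt)  # padded targets are ignored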
Example #15
    def forward(self, words, seq_len=None):
        """

        :param torch.LongTensor words: [batch_size, seq_len],句子中word的index
        :param torch.LongTensor seq_len:  [batch,] 每个句子的长度
        :return output: dict of torch.LongTensor, [batch_size, num_classes]
        """
        x = self.embed(words)  # [N,L] -> [N,L,C]
        if seq_len is not None:
            mask = seq_len_to_mask(seq_len)
            x = self.conv_pool(x, mask)
        else:
            x = self.conv_pool(x)  # [N,L,C] -> [N,C]
        x = self.dropout(x)
        x = self.fc(x)  # [N,C] -> [N, N_class]
        return {C.OUTPUT: x}
Example #16
    def _check_potentials(self, scores, lengths=None):
        semiring = self.semiring
        batch, N, N2, N3 = self._get_dimension_and_requires_grad(scores)
        assert N == N2 == N3, "Non-square potentials"

        if lengths is None:
            lengths = torch.LongTensor([N - 1] * batch).to(scores.device)
        else:
            assert max(lengths) <= N, "Length longer than N"

        scores = semiring.convert(scores)
        scores = scores.clone()  # avoid leaf error when backward

        mask = seq_len_to_mask(lengths + 1, N)
        mask3d = (mask.unsqueeze(-1) *
                  mask.unsqueeze(-2)).unsqueeze(-1) * mask.view(
                      batch, 1, 1, N)
        semiring.zero_mask_(scores, ~mask3d)

        return scores, batch, N, lengths
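The mask3d built above keeps a potential scores[b, i, j, k] only if all three indices fall inside sentence b's valid window (length + 1). A small sketch of the same broadcasting, assuming a boolean mask:

import torch

batch, N = 2, 4
lengths = torch.tensor([3, 2])
mask = torch.arange(N).unsqueeze(0) < (lengths + 1).unsqueeze(1)  # (batch, N)
mask3d = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).unsqueeze(-1) \
         * mask.view(batch, 1, 1, N)                              # (batch, N, N, N)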
Example #17
    def _forward(self, chars, bigrams, trigrams, seq_len, target=None):
        chars = self.char_embed(chars)
        if bigrams is not None:
            bigrams = self.bigram_embed(bigrams)
            chars = torch.cat([chars, bigrams], dim=-1)
        if trigrams is not None:
            trigrams = self.trigram_embed(trigrams)
            chars = torch.cat([chars, trigrams], dim=-1)

        output, _ = self.lstm(chars, seq_len)
        output = self.dropout(output)
        output = self.fc(output)
        output = F.log_softmax(output, dim=-1)
        mask = seq_len_to_mask(seq_len)
        if target is None:
            pred, _ = self.crf.viterbi_decode(output, mask)
            return {Const.OUTPUT: pred}
        else:
            loss = self.crf.forward(output, tags=target, mask=mask)
            return {Const.LOSS: loss}
Example #18
    def forward(self, chars, bigrams, seq_len, target):
        if self.debug:

            print_info('chars:{}'.format(chars.size()))
            print_info('bigrams:{}'.format(bigrams.size()))
            print_info('seq_len:{}'.format(seq_len.size()))
            print_info('target:{}'.format(target.size()))
        embed_char = self.char_embed(chars)

        if self.use_bigram:

            embed_bigram = self.bigram_embed(bigrams)

            embedding = torch.cat([embed_char, embed_bigram], dim=-1)
        else:

            embedding = embed_char

        embedding = self.embed_dropout(embedding)

        encoded_h, encoded_c = self.encoder(embedding, seq_len)

        encoded_h = self.output_dropout(encoded_h)

        pred = self.output(encoded_h)

        mask = seq_len_to_mask(seq_len)

        # pred = self.crf(pred)

        # batch_size, sent_len = pred.shape[0], pred.shape[1]
        # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
        if self.debug:
            print('debug mode:finish')
            exit(1208)
        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            return {'pred': pred}
Example #19
    def forward(self,
                lattice,
                bigrams,
                seq_len,
                lex_num,
                pos_s,
                pos_e,
                pos_tag,
                target=None,
                chars_target=None):
        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        words = lattice[:, :max_seq_len]
        mask = seq_len_to_mask(seq_len).bool()
        words.masked_fill_((~mask), self.vocabs['lattice'].padding_idx)

        encoded = self.bert_embedding(words)

        if self.use_pos_tag:
            pos_embed = self.pos_embedding(pos_tag)
            encoded = torch.cat([encoded, pos_embed], dim=-1)

        if self.after_bert == 'lstm':
            encoded, _ = self.lstm(encoded, seq_len)
            encoded = self.dropout(encoded)

        pred = self.output(encoded)
        if self.training:
            # loss = self.crf(pred, target, mask).mean(dim=0)
            loss = self.crf(emissions=pred, tags=target, mask=mask).mean(dim=0)
            return {'loss': -loss}
        else:
            pred = self.crf.decode(emissions=pred, mask=mask).squeeze(0)
            # pred, path = self.crf.viterbi_decode(pred, mask)
            # print(pred.shape)
            result = {'pred': pred}
            return result
Example #20
    def _forward(self,
                 chars,
                 bigrams=None,
                 trigrams=None,
                 seq_len=None,
                 target=None):
        chars = self.char_embed(chars)
        if hasattr(self, 'bigram_embed'):
            bigrams = self.bigram_embed(bigrams)
            chars = torch.cat((chars, bigrams), dim=-1)
        if hasattr(self, 'trigram_embed'):
            trigrams = self.trigram_embed(trigrams)
            chars = torch.cat((chars, trigrams), dim=-1)
        feats, _ = self.lstm(chars, seq_len=seq_len)
        feats = self.fc(feats)
        feats = self.dropout(feats)
        logits = F.log_softmax(feats, dim=-1)
        mask = seq_len_to_mask(seq_len)
        if target is None:
            pred, _ = self.crf.viterbi_decode(logits, mask)
            return {C.OUTPUT: pred}
        else:
            loss = self.crf(logits, target, mask).mean()
            return {C.LOSS: loss}
Example #21
    def forward(self,
                lattice,
                bigrams,
                seq_len,
                lex_num,
                pos_s,
                pos_e,
                target,
                chars_target=None):
        if self.mode['debug']:
            print('lattice:{}'.format(lattice))
            print('bigrams:{}'.format(bigrams))
            print('seq_len:{}'.format(seq_len))
            print('lex_num:{}'.format(lex_num))
            print('pos_s:{}'.format(pos_s))
            print('pos_e:{}'.format(pos_e))

        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        raw_embed = self.lattice_embed(lattice)
        # raw_embed holds the pretrained embeddings of characters and words; they were trained separately, so they need separate handling
        if self.use_bigram:
            bigrams_embed = self.bigram_embed(bigrams)
            bigrams_embed = torch.cat([
                bigrams_embed,
                torch.zeros(size=[
                    batch_size, max_seq_len_and_lex_num -
                    max_seq_len, self.bigram_size
                ]).to(bigrams_embed)
            ],
                                      dim=1)
            raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
        else:
            raw_embed_char = raw_embed
        # print('raw_embed_char_1:{}'.format(raw_embed_char[:1,:3,-5:]))

        if self.use_bert:
            bert_pad_length = lattice.size(1) - max_seq_len
            char_for_bert = lattice[:, :max_seq_len]
            mask = seq_len_to_mask(seq_len).bool()
            char_for_bert = char_for_bert.masked_fill(
                (~mask), self.vocabs['lattice'].padding_idx)
            bert_embed = self.bert_embedding(char_for_bert)
            bert_embed = torch.cat([
                bert_embed,
                torch.zeros(
                    size=[batch_size, bert_pad_length,
                          bert_embed.size(-1)],
                    device=bert_embed.device,
                    requires_grad=False)
            ],
                                   dim=-2)
            # print('bert_embed:{}'.format(bert_embed[:1, :3, -5:]))
            raw_embed_char = torch.cat([raw_embed_char, bert_embed], dim=-1)

        # print('raw_embed_char:{}'.format(raw_embed_char[:1,:3,-5:]))

        if self.embed_dropout_pos == '0':
            raw_embed_char = self.embed_dropout(raw_embed_char)
            raw_embed = self.gaz_dropout(raw_embed)

        # print('raw_embed_char_dp:{}'.format(raw_embed_char[:1,:3,-5:]))

        embed_char = self.char_proj(raw_embed_char)
        # print('char_proj:',list(self.char_proj.parameters())[0].data[:2][:2])
        # print('embed_char_:{}'.format(embed_char[:1,:3,:4]))

        if self.mode['debug']:
            print('embed_char:{}'.format(embed_char[:2]))
        char_mask = seq_len_to_mask(seq_len,
                                    max_len=max_seq_len_and_lex_num).bool()
        # if self.embed_dropout_pos == '1':
        #     embed_char = self.embed_dropout(embed_char)
        embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

        embed_lex = self.lex_proj(raw_embed)
        if self.mode['debug']:
            print('embed_lex:{}'.format(embed_lex[:2]))
        # if self.embed_dropout_pos == '1':
        #     embed_lex = self.embed_dropout(embed_lex)

        lex_mask = (seq_len_to_mask(seq_len + lex_num).bool()
                    ^ char_mask.bool())
        embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

        assert char_mask.size(1) == lex_mask.size(1)
        # print('embed_char:{}'.format(embed_char[:1,:3,:4]))
        # print('embed_lex:{}'.format(embed_lex[:1,:3,:4]))

        embedding = embed_char + embed_lex
        if self.mode['debug']:
            print('embedding:{}'.format(embedding[:2]))

        if self.embed_dropout_pos == '1':
            embedding = self.embed_dropout(embedding)

        if self.use_abs_pos:
            embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

        if self.embed_dropout_pos == '2':
            embedding = self.embed_dropout(embedding)
        # embedding = self.embed_dropout(embedding)
        # print('*1*')
        # print(embedding.size())
        # print('merged_embedding:{}'.format(embedding[:1,:3,:4]))
        # exit()
        encoded = self.encoder(embedding,
                               seq_len,
                               lex_num=lex_num,
                               pos_s=pos_s,
                               pos_e=pos_e)

        if hasattr(self, 'output_dropout'):
            encoded = self.output_dropout(encoded)

        encoded = encoded[:, :max_seq_len, :]
        pred = self.output(encoded)

        mask = seq_len_to_mask(seq_len).bool()

        if self.mode['debug']:
            print('debug mode:finish!')
            exit(1208)
        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
            if self.self_supervised:
                # print('self supervised loss added!')
                chars_pred = self.output_self_supervised(encoded)
                chars_pred = chars_pred.view(
                    size=[batch_size * max_seq_len, -1])
                chars_target = chars_target.view(
                    size=[batch_size * max_seq_len])
                self_supervised_loss = self.loss_func(chars_pred, chars_target)
                # print('self_supervised_loss:{}'.format(self_supervised_loss))
                # print('supervised_loss:{}'.format(loss))
                loss += self_supervised_loss
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}
            if self.self_supervised:
                chars_pred = self.output_self_supervised(encoded)
                result['chars_pred'] = chars_pred

            return result
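In the lattice models above, the first seq_len positions of the padded sequence are characters and the next lex_num positions are matched lexicon words; XOR-ing the two length masks isolates the lexicon slots. A small sketch of that split with made-up lengths:

import torch

seq_len, lex_num, total_len = torch.tensor([3]), torch.tensor([2]), 6
idx = torch.arange(total_len).unsqueeze(0)
char_mask = idx < seq_len.unsqueeze(1)              # [[T, T, T, F, F, F]]
full_mask = idx < (seq_len + lex_num).unsqueeze(1)  # [[T, T, T, T, T, F]]
lex_mask = full_mask ^ char_mask                    # [[F, F, F, T, T, F]]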
Example #22
    def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e,
                target, span_label, attr_start_label, attr_end_label, chars_target=None):
        self.steps += 1
        if self.mode['debug']:
            print('lattice:{} {}'.format(lattice.shape, lattice))
            print('bigrams:{} {}'.format(bigrams.shape, bigrams))
            print('seq_len:{} {}'.format(seq_len.shape, seq_len))
            print('lex_num:{} {}'.format(lex_num.shape, lex_num))
            print('pos_s:{} {}'.format(pos_s.shape, pos_s))
            print('pos_e:{} {}'.format(pos_e.shape, pos_e))
            print('span_label:{} {}'.format(span_label.shape, span_label))
            print('attr_start_label:{} {}'.format(attr_start_label.shape, attr_start_label))
            print('attr_end_label: {} {}'.format(attr_end_label.shape, attr_end_label))
            exit(1228)

        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        raw_embed = self.lattice_embed(lattice)
        # raw_embed holds the pretrained embeddings of characters and words; they were trained separately, so they need separate handling
        if self.use_bigram:
            bigrams_embed = self.bigram_embed(bigrams)
            bigrams_embed = torch.cat([bigrams_embed,
                                       torch.zeros(size=[batch_size, max_seq_len_and_lex_num - max_seq_len,
                                                         self.bigram_size]).to(bigrams_embed)], dim=1)
            raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
        else:
            raw_embed_char = raw_embed
        # print('raw_embed_char_1:{}'.format(raw_embed_char[:1,:3,-5:]))

        if self.use_bert:
            bert_pad_length = lattice.size(1) - max_seq_len
            char_for_bert = lattice[:, :max_seq_len]
            mask = seq_len_to_mask(seq_len).bool()
            char_for_bert = char_for_bert.masked_fill((~mask), self.vocabs['lattice'].padding_idx)
            bert_embed = self.bert_embedding(char_for_bert)
            bert_embed = torch.cat([bert_embed,
                                    torch.zeros(size=[batch_size, bert_pad_length, bert_embed.size(-1)],
                                                device=bert_embed.device,
                                                requires_grad=False)], dim=-2)
            # print('bert_embed:{}'.format(bert_embed[:1, :3, -5:]))
            raw_embed_char = torch.cat([raw_embed_char, bert_embed], dim=-1)

        # print('raw_embed_char:{}'.format(raw_embed_char[:1,:3,-5:]))

        if self.embed_dropout_pos == '0':
            raw_embed_char = self.embed_dropout(raw_embed_char)
            raw_embed = self.gaz_dropout(raw_embed)

        # print('raw_embed_char_dp:{}'.format(raw_embed_char[:1,:3,-5:]))

        embed_char = self.char_proj(raw_embed_char)
        # print('char_proj:',list(self.char_proj.parameters())[0].data[:2][:2])
        # print('embed_char_:{}'.format(embed_char[:1,:3,:4]))

        if self.mode['debug']:
            print('embed_char:{}'.format(embed_char[:2]))
        char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num).bool()
        # if self.embed_dropout_pos == '1':
        #     embed_char = self.embed_dropout(embed_char)
        embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

        embed_lex = self.lex_proj(raw_embed)
        if self.mode['debug']:
            print('embed_lex:{}'.format(embed_lex[:2]))
        # if self.embed_dropout_pos == '1':
        #     embed_lex = self.embed_dropout(embed_lex)

        lex_mask = (seq_len_to_mask(seq_len + lex_num).bool() ^ char_mask.bool())
        embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

        assert char_mask.size(1) == lex_mask.size(1)
        # print('embed_char:{}'.format(embed_char[:1,:3,:4]))
        # print('embed_lex:{}'.format(embed_lex[:1,:3,:4]))

        embedding = embed_char + embed_lex
        if self.mode['debug']:
            print('embedding:{}'.format(embedding[:2]))

        if self.embed_dropout_pos == '1':
            embedding = self.embed_dropout(embedding)

        if self.use_abs_pos:
            embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

        if self.embed_dropout_pos == '2':
            embedding = self.embed_dropout(embedding)
        # embedding = self.embed_dropout(embedding)
        # print('*1*')
        # print(embedding.size())
        # print('merged_embedding:{}'.format(embedding[:1,:3,:4]))
        # exit()

        mask = seq_len_to_mask(seq_len).bool()
        # TODO: add our PLE module
        if self.new_tag_scheme:
            encodeds = []
            for _i in range(self.ple_channel_num):
                encoded = self.encoder_list[_i](embedding, seq_len, lex_num=lex_num, pos_s=pos_s, pos_e=pos_e)
                if hasattr(self, 'output_dropout'):
                    encoded = self.output_dropout(encoded)
                encoded = encoded[:, :max_seq_len, :]
                encodeds.append(encoded)
            if self.ple_channel_num == 1:
                span_logits, attr_start_logits, attr_end_logits = self.ple(encodeds[0], encodeds[0], encodeds[0])
            else:
                span_logits, attr_start_logits, attr_end_logits = self.ple(encodeds[0], encodeds[1], encodeds[2])
            if self.training:
                inputs_seq_len = mask.sum(dim=-1).float()
                span_loss = (self.crf(span_logits, span_label, mask) / inputs_seq_len).mean(dim=0)
                attr_start_loss = self.attr_criterion(attr_start_logits.permute(0, 2, 1), attr_start_label)  # B * S
                attr_start_loss = (torch.sum(attr_start_loss * mask.float(), dim=-1).float() / inputs_seq_len).mean()  # B
                attr_end_loss = self.attr_criterion(attr_end_logits.permute(0, 2, 1), attr_end_label)  # B * S
                attr_end_loss = (torch.sum(attr_end_loss * mask.float(), dim=-1).float() / inputs_seq_len).mean()  # B
                loss = (self.span_loss_alpha * span_loss + attr_start_loss + attr_end_loss) / 3
                # if torch.isnan(span_loss.mean()) or torch.abs(span_loss.mean()) > 50:
                if self.steps % 50 == 0:
                    print(f"span_loss: {span_loss}; attr_start_loss: {attr_start_loss}; attr_end_loss: {attr_end_loss}")
                # loss = (attr_start_loss + attr_end_loss) / 3
                return {"loss": loss}
            else:
                # span_pred, path = self.crf.viterbi_decode(span_logits, mask)
                attr_start_pred = attr_start_logits.argmax(dim=-1)
                attr_end_pred = attr_end_logits.argmax(dim=-1)
                ner_pred = convert_attr_seq_to_ner_seq(attr_start_pred, attr_end_pred, self.vocabs, tagscheme='BMOES')
                return {'pred': ner_pred}
        else:
            encoded = self.encoder(embedding, seq_len, lex_num=lex_num, pos_s=pos_s, pos_e=pos_e)
            if hasattr(self, 'output_dropout'):
                encoded = self.output_dropout(encoded)
            encoded = encoded[:, :max_seq_len, :]
            pred = self.output(encoded)
            if self.mode['debug']:
                print('debug mode:finish!')
                exit(1208)
            if self.training:
                loss = self.crf(pred, target, mask).mean(dim=0)
                if self.self_supervised:
                    # print('self supervised loss added!')
                    chars_pred = self.output_self_supervised(encoded)
                    chars_pred = chars_pred.view(size=[batch_size * max_seq_len, -1])
                    chars_target = chars_target.view(size=[batch_size * max_seq_len])
                    self_supervised_loss = self.loss_func(chars_pred, chars_target)
                    # print('self_supervised_loss:{}'.format(self_supervised_loss))
                    # print('supervised_loss:{}'.format(loss))
                    loss += self_supervised_loss
                return {'loss': loss}
            else:
                pred, path = self.crf.viterbi_decode(pred, mask)
                result = {'pred': pred}
                if self.self_supervised:
                    chars_pred = self.output_self_supervised(encoded)
                    result['chars_pred'] = chars_pred

                return result
Example #23
    def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None,
                pre_trigrams=None):
        """
        max_len是包含root的
        :param chars: batch_size x max_len
        :param ngrams: batch_size x max_len*ngram_per_char
        :param seq_lens: batch_size
        :param gold_heads: batch_size x max_len
        :param pre_chars: batch_size x max_len
        :param pre_ngrams: batch_size x max_len*ngram_per_char
        :return dict: parsing results
            arc_pred: [batch_size, seq_len, seq_len]
            label_pred: [batch_size, seq_len, seq_len]
            mask: [batch_size, seq_len]
            head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
        """
        # prepare embeddings
        batch_size, seq_len = chars.shape
        # print('forward {} {}'.format(batch_size, seq_len))

        # get sequence mask
        mask = seq_len_to_mask(seq_lens).long()

        chars = self.char_embed(chars) # [N,L] -> [N,L,C_0]
        bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1]
        trigrams = self.trigram_embed(trigrams)

        if pre_chars is not None:
            pre_chars = self.pre_char_embed(pre_chars)
            # pre_chars = self.pre_char_fc(pre_chars)
            chars = pre_chars + chars
        if pre_bigrams is not None:
            pre_bigrams = self.pre_bigram_embed(pre_bigrams)
            # pre_bigrams = self.pre_bigram_fc(pre_bigrams)
            bigrams = bigrams + pre_bigrams
        if pre_trigrams is not None:
            pre_trigrams = self.pre_trigram_embed(pre_trigrams)
            # pre_trigrams = self.pre_trigram_fc(pre_trigrams)
            trigrams = trigrams + pre_trigrams

        x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C]

        # encoder, extract features
        if self.training:
            x = drop_input_independent(x, self.dropout)
        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
        x = x[sort_idx]
        x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
        feat, _ = self.encoder(x)  # -> [N,L,C]
        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
        _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
        feat = feat[unsort_idx]
        feat = self.timestep_drop(feat)

        # for arc biaffine
        # mlp, reduce dim
        feat = self.mlp(feat)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]

        # use gold or predicted arc to predict label
        if gold_heads is None or not self.training:
            # use greedy decoding in training
            if self.training or self.use_greedy_infer:
                heads = self.greedy_decoder(arc_pred, mask)
            else:
                heads = self.mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training # must be training mode
            if gold_heads is None:
                heads = self.greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = gold_heads
        # heads: batch_size x max_len

        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label]
        # Restrict here: the 'app' label may only be predicted when the head is the next token
        arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\
            .repeat(batch_size, 1) # batch_size x max_len
        app_masks = heads.ne(arange_index)  # batch_size x max_len; positions equal to 1 must not predict 'app'
        app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
        app_masks[:, :, 1:] = 0
        label_pred = label_pred.masked_fill(app_masks, -np.inf)

        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
        if head_pred is not None:
            res_dict['head_pred'] = head_pred
        return res_dict
Example #24
    def forward(self, sentence, aspect, pos_class, dep_tags, text_len,
                aspect_len, dep_rels, dep_heads, aspect_position, dep_dirs):
        '''
        Forward takes:
            sentence: sentence_id of size (batch_size, text_length)
            aspect: aspect_id of size (batch_size, aspect_length)
            pos_class: pos_tag_id of size (batch_size, text_length)
            dep_tags: dep_tag_id of size (batch_size, text_length)
            text_len: (batch_size,) length of each sentence
            aspect_len: (batch_size, ) aspect length of each sentence
            dep_rels: (batch_size, text_length) relation
            dep_heads: (batch_size, text_length) which node adjacent to that node
            aspect_position: (batch_size, text_length) mask, with the position of aspect as 1 and others as 0
            dep_dirs: (batch_size, text_length) the directions each node to the aspect
        '''
        fmask = seq_len_to_mask(text_len).float()
        # fmask = (torch.zeros_like(sentence) != sentence).float()  # (N, L), padding is 0
        # dmask = (torch.zeros_like(dep_tags) != dep_tags).float()  # (N ,L)
        if self.training:
            mask = torch.rand(sentence.size()).lt(0.02).to(sentence.device)
            sentence = sentence.masked_fill(mask, 0)
            # mask = torch.rand(aspect.size()).lt(0.01).to(sentence.device)
            # aspect = aspect.masked_fill(mask, 0)

        feature = self.embed(sentence)  # (N, L, D)
        aspect_feature = self.embed(aspect)  # (N, L', D)
        feature = self.dropout(feature)
        aspect_feature = self.dropout(aspect_feature)

        if self.args.highway:
            feature = self.highway(feature)
            aspect_feature = self.highway(aspect_feature)

        feature, _ = self.bilstm(feature, seq_len=text_len)  # (N,L,D)
        aspect_feature, _ = self.bilstm(aspect_feature,
                                        seq_len=aspect_len)  #(N,L,D)
        aspect_mask = seq_len_to_mask(aspect_len)

        # aspect_feature = aspect_feature.masked_fill(aspect_mask.eq(0).unsqueeze(-1), 0)
        # aspect_feature = aspect_feature.sum(dim=1)/aspect_len.unsqueeze(1).float()
        aspect_feature = aspect_feature.masked_fill(
            aspect_mask.eq(0).unsqueeze(-1), -10000)
        aspect_feature, _ = aspect_feature.max(dim=1)
        # aspect_feature = aspect_feature.mean(dim=1)

        ############################################################################################
        # do gat thing
        dep_feature = self.dep_embed(dep_tags)
        # dep_feature = self.dropout(dep_feature)
        dep_feature = F.dropout(dep_feature, p=0.7, training=self.training)
        if self.args.highway:
            dep_feature = self.highway_dep(dep_feature)

        dep_out = [
            g(feature, dep_feature, fmask).unsqueeze(1) for g in self.gat_dep
        ]  # (N, 1, D) * num_heads
        dep_out = torch.cat(dep_out, dim=1)  # (N, H, D)
        # dep_out = dep_out.mean(dim = 1) # (N, D)
        dep_out, _ = dep_out.max(dim=1)  # (N, D)

        if self.args.gat_attention_type == 'gcn':
            gat_out = self.gat(feature)  # (N, L, D)
            fmask = fmask.unsqueeze(2)
            gat_out = gat_out * fmask
            gat_out = F.relu(torch.sum(gat_out, dim=1))  # (N, D)
        else:
            gat_out = [
                g(feature, aspect_feature, fmask).unsqueeze(1)
                for g in self.gat
            ]
            gat_out = torch.cat(gat_out, dim=1)
            # gat_out = gat_out.mean(dim=1)
            gat_out, _ = gat_out.max(dim=1)

        feature_out = torch.cat([dep_out, gat_out], dim=1)  # (N, D')
        # feature_out = gat_out
        #############################################################################################
        feature_out = self.dropout(feature_out)
        # feature_out = F.dropout(feature_out, p=0.3, training=self.training)
        x = self.fcs(feature_out)
        logit = self.fc_final(x)
        return logit
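The aspect pooling above is a masked max pooling: padded timesteps are filled with -10000 so they can never be selected by the max over the length dimension. A small illustrative sketch with made-up tensors:

import torch

feat = torch.randn(2, 4, 8)                        # (N, L', D)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # e.g. from seq_len_to_mask(aspect_len)
pooled, _ = feat.masked_fill(mask.eq(0).unsqueeze(-1), -10000).max(dim=1)  # (N, D)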
    def forward(self,
                lattice,
                bigrams,
                seq_len,
                lex_num,
                pos_s,
                pos_e,
                target,
                chars_target=None):

        if self.mode['debug']:  # taking the first sample as an example
            print('lattice:{}'.format(lattice))  # 21 + 12 indices, padded at the end
            print('bigrams:{}'.format(bigrams))  # 21 indices, different from the start of lattice
            print('seq_len:{}'.format(seq_len))  # 21
            print('lex_num:{}'.format(lex_num))  # 12
            print('pos_s:{}'.format(pos_s))  # 0, 1, 2, ...
            print('pos_e:{}'.format(pos_e))  # 0, 1, 2, ...

        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        raw_embed = self.lattice_embed(lattice)  # look up the lattice embeddings

        if self.use_bigram:
            bigrams_embed = self.bigram_embed(bigrams)
            bigrams_embed = torch.cat([
                bigrams_embed,
                torch.zeros(size=[
                    batch_size, max_seq_len_and_lex_num -
                    max_seq_len, self.bigram_size
                ]).to(bigrams_embed)
            ],
                                      dim=1)
            raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)  #
        else:
            raw_embed_char = raw_embed

        # dim2 = 0
        # dim3 = 2

        if self.embed_dropout_pos == '0':
            raw_embed_char = self.embed_dropout(raw_embed_char)
            raw_embed = self.gaz_dropout(raw_embed)

        embed_char = self.char_proj(raw_embed_char)  # linear projection
        char_mask = seq_len_to_mask(seq_len,
                                    max_len=max_seq_len_and_lex_num).bool()
        embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)  # zero out non-character positions

        embed_lex = self.lex_proj(raw_embed)
        lex_mask = (seq_len_to_mask(seq_len + lex_num).bool()
                    ^ char_mask.bool())
        embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

        assert char_mask.size(1) == lex_mask.size(1)

        embedding = embed_char + embed_lex  # adding these two together here looks rather odd

        if self.embed_dropout_pos == '1':
            embedding = self.embed_dropout(embedding)  #dropout

        if self.use_abs_pos:
            embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

        if self.embed_dropout_pos == '2':
            embedding = self.embed_dropout(embedding)

        encoded = self.encoder(embedding,
                               seq_len,
                               lex_num=lex_num,
                               pos_s=pos_s,
                               pos_e=pos_e,
                               print_=(self.batch_num == 327))

        if hasattr(self, 'output_dropout'):
            encoded = self.output_dropout(encoded)

        encoded = encoded[:, :max_seq_len, :]
        pred = self.output(encoded)

        mask = seq_len_to_mask(seq_len).bool()

        if self.mode['debug']:
            print('debug mode:finish!')
            exit(1208)

        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
            if self.self_supervised:
                # print('self supervised loss added!')
                chars_pred = self.output_self_supervised(encoded)
                chars_pred = chars_pred.view(
                    size=[batch_size * max_seq_len, -1])
                chars_target = chars_target.view(
                    size=[batch_size * max_seq_len])
                self_supervised_loss = self.loss_func(chars_pred, chars_target)
                # print('self_supervised_loss:{}'.format(self_supervised_loss))
                # print('supervised_loss:{}'.format(loss))
                loss += self_supervised_loss

            if self.batch_num == 327:
                print('{} loss:{}'.format(self.batch_num, loss))
                exit()

            # exit()
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}
            if self.self_supervised:
                chars_pred = self.output_self_supervised(encoded)
                result['chars_pred'] = chars_pred

            return result
Example #26
    def forward(self,
                lattice,
                bigrams,
                seq_len,
                lex_num,
                pos_s,
                pos_e,
                target,
                chars_target=None):
        # if self.training:
        #     self.batch_num+=1
        # if self.batch_num == 1000:
        #     exit()

        # print('lattice:')
        # print(lattice)
        if self.mode['debug']:
            print('lattice:{}'.format(lattice))
            print('bigrams:{}'.format(bigrams))
            print('seq_len:{}'.format(seq_len))
            print('lex_num:{}'.format(lex_num))
            print('pos_s:{}'.format(pos_s))
            print('pos_e:{}'.format(pos_e))

        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        raw_embed = self.lattice_embed(lattice)
        # raw_embed holds the pretrained embeddings of characters and words; they were trained separately, so they need separate handling
        if self.use_bigram:
            bigrams_embed = self.bigram_embed(bigrams)
            bigrams_embed = torch.cat([
                bigrams_embed,
                torch.zeros(size=[
                    batch_size, max_seq_len_and_lex_num -
                    max_seq_len, self.bigram_size
                ]).to(bigrams_embed)
            ],
                                      dim=1)
            raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
        else:
            raw_embed_char = raw_embed

        dim2 = 0
        dim3 = 2
        # print('raw_embed:{}'.format(raw_embed[:,dim2,:dim3]))
        # print('raw_embed_char:{}'.format(raw_embed_char[:, dim2, :dim3]))
        if self.embed_dropout_pos == '0':
            raw_embed_char = self.embed_dropout(raw_embed_char)
            raw_embed = self.gaz_dropout(raw_embed)
        # print('raw_embed_dropout:{}'.format(raw_embed[:,dim2,:dim3]))
        # print('raw_embed_char_dropout:{}'.format(raw_embed_char[:, dim2, :dim3]))

        embed_char = self.char_proj(raw_embed_char)
        if self.mode['debug']:
            print('embed_char:{}'.format(embed_char[:2]))
        char_mask = seq_len_to_mask(seq_len,
                                    max_len=max_seq_len_and_lex_num).bool()
        # if self.embed_dropout_pos == '1':
        #     embed_char = self.embed_dropout(embed_char)
        embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

        embed_lex = self.lex_proj(raw_embed)
        if self.mode['debug']:
            print('embed_lex:{}'.format(embed_lex[:2]))
        # if self.embed_dropout_pos == '1':
        #     embed_lex = self.embed_dropout(embed_lex)

        lex_mask = (seq_len_to_mask(seq_len + lex_num).bool()
                    ^ char_mask.bool())
        embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

        assert char_mask.size(1) == lex_mask.size(1)

        embedding = embed_char + embed_lex
        if self.mode['debug']:
            print('embedding:{}'.format(embedding[:2]))

        if self.embed_dropout_pos == '1':
            embedding = self.embed_dropout(embedding)

        if self.use_abs_pos:
            embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

        if self.embed_dropout_pos == '2':
            embedding = self.embed_dropout(embedding)
        # embedding = self.embed_dropout(embedding)

        # print('embedding:{}'.format(embedding[:,dim2,:dim3]))

        if self.batch_num == 327:
            print('{} embed:{}'.format(self.batch_num, embedding[:2,
                                                                 dim2, :dim3]))

        encoded = self.encoder(embedding,
                               seq_len,
                               lex_num=lex_num,
                               pos_s=pos_s,
                               pos_e=pos_e,
                               print_=(self.batch_num == 327))

        if self.batch_num == 327:
            print('{} encoded:{}'.format(self.batch_num, encoded[:2,
                                                                 dim2, :dim3]))

        if hasattr(self, 'output_dropout'):
            encoded = self.output_dropout(encoded)

        encoded = encoded[:, :max_seq_len, :]
        pred = self.output(encoded)

        if self.batch_num == 327:
            print('{} pred:{}'.format(self.batch_num, pred[:2, dim2, :dim3]))

        # print('pred:{}'.format(pred[:,dim2,:dim3]))
        # exit()

        mask = seq_len_to_mask(seq_len).bool()

        if self.mode['debug']:
            print('debug mode:finish!')
            exit(1208)
        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
            if self.self_supervised:
                # print('self supervised loss added!')
                chars_pred = self.output_self_supervised(encoded)
                chars_pred = chars_pred.view(
                    size=[batch_size * max_seq_len, -1])
                chars_target = chars_target.view(
                    size=[batch_size * max_seq_len])
                self_supervised_loss = self.loss_func(chars_pred, chars_target)
                # print('self_supervised_loss:{}'.format(self_supervised_loss))
                # print('supervised_loss:{}'.format(loss))
                loss += self_supervised_loss

            if self.batch_num == 327:
                print('{} loss:{}'.format(self.batch_num, loss))
                exit()

            # exit()
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}
            if self.self_supervised:
                chars_pred = self.output_self_supervised(encoded)
                result['chars_pred'] = chars_pred

            return result