Example #1
    def __init__(self,
                 bert_embedding,
                 label_size,
                 vocabs,
                 after_bert,
                 use_pos_tag=True):
        super().__init__()
        self.after_bert = after_bert
        self.bert_embedding = bert_embedding
        self.label_size = label_size
        self.vocabs = vocabs
        self.hidden_size = bert_embedding._embed_size
        self.use_pos_tag = use_pos_tag
        self.pos_feats_size = 0

        if self.use_pos_tag:
            self.pos_embed_size = len(list(vocabs['pos_tag']))
            self.pos_feats_size = 20
            self.pos_embedding = nn.Embedding(self.pos_embed_size,
                                              self.pos_feats_size)

        if self.after_bert == 'lstm':
            self.lstm = nn.LSTM(
                bert_embedding._embed_size + self.pos_feats_size,
                (bert_embedding._embed_size + self.pos_feats_size) // 2,
                bidirectional=True,
                num_layers=2,
            )
        self.output = nn.Linear(self.hidden_size + self.pos_feats_size,
                                self.label_size)
        self.dropout = MyDropout(0.2)
        # self.crf = get_crf_zero_init(self.label_size)
        self.crf = CRF(num_tags=self.label_size, batch_first=True)
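A sketch of a matching forward pass is shown below; it is not part of the original listing. The method name, argument names, and the shape convention of bert_embedding are assumptions; only the modules built in __init__ above are taken from the example, and self.crf is assumed to be torchcrf's CRF with batch_first=True, as constructed there.

    # Hypothetical forward pass for the head above; a sketch, not the original source.
    def forward(self, words, pos_tags, mask, target=None):
        feats = self.bert_embedding(words)                        # assumed (B, L, H)
        if self.use_pos_tag:
            feats = torch.cat([feats, self.pos_embedding(pos_tags)], dim=-1)
        if self.after_bert == 'lstm':
            # the nn.LSTM above is built without batch_first, so feed it sequence-first
            out, _ = self.lstm(feats.transpose(0, 1))
            feats = out.transpose(0, 1)
        feats = self.dropout(feats)
        emissions = self.output(feats)                            # (B, L, label_size)
        if target is not None:
            # torchcrf returns a log-likelihood; negate it to use it as a loss
            return -self.crf(emissions, target, mask=mask.bool(), reduction='mean')
        return self.crf.decode(emissions, mask=mask.bool())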
Example #2
    def __init__(self, bert_embedding, label_size, vocabs, after_bert):
        super().__init__()
        self.after_bert = after_bert
        self.bert_embedding = bert_embedding
        self.label_size = label_size
        self.vocabs = vocabs
        self.hidden_size = bert_embedding._embed_size
        self.output = nn.Linear(self.hidden_size, self.label_size)
        self.crf = get_crf_zero_init(self.label_size)
        if self.after_bert == 'lstm':
            self.lstm = LSTM(bert_embedding._embed_size, bert_embedding._embed_size // 2,
                             bidirectional=True)
        self.dropout = MyDropout(0.5)
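MyDropout is used throughout these examples but is not defined in this listing. A minimal sketch of an inverted dropout that is consistent with how it is used here follows; it is an assumption, not the repository's code.

import torch
import torch.nn as nn

class MyDropout(nn.Module):
    """Minimal inverted dropout; a guess at the helper used above, not the original source."""

    def __init__(self, p):
        super().__init__()
        assert 0 <= p <= 1
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0:
            # drop each entry with probability p and rescale the survivors
            mask = torch.rand_like(x) < self.p
            x = x.masked_fill(mask, 0) / (1 - self.p)
        return x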
Example #3
    def __init__(self,
                 lattice_embed,
                 bigram_embed,
                 hidden_size,
                 label_size,
                 num_heads,
                 num_layers,
                 use_abs_pos,
                 use_rel_pos,
                 learnable_position,
                 add_position,
                 layer_preprocess_sequence,
                 layer_postprocess_sequence,
                 ff_size=-1,
                 scaled=True,
                 dropout=None,
                 use_bigram=True,
                 mode=collections.defaultdict(bool),
                 dvc=None,
                 vocabs=None,
                 rel_pos_shared=True,
                 max_seq_len=-1,
                 k_proj=True,
                 q_proj=True,
                 v_proj=True,
                 r_proj=True,
                 self_supervised=False,
                 attn_ff=True,
                 pos_norm=False,
                 ff_activate='relu',
                 rel_pos_init=0,
                 abs_pos_fusion_func='concat',
                 embed_dropout_pos='0',
                 four_pos_shared=True,
                 four_pos_fusion=None,
                 four_pos_fusion_shared=True,
                 bert_embedding=None,
                 use_pytorch_dropout=False):
        '''
        :param rel_pos_init: if 0, the relative position embedding matrix covering -max_len..max_len is initialized
            with indices 0..2*max_len; if 1, it is initialized directly with indices -max_len..max_len.

        :param embed_dropout_pos: if '0', dropout is applied right after the embedding lookup; if '1', it is applied
            after the embedding has been projected to hidden size; if '2', it is applied after the absolute position
            encoding has been added.
        '''
        super().__init__()
        self.use_pytorch_dropout = use_pytorch_dropout
        self.four_pos_fusion_shared = four_pos_fusion_shared
        self.mode = mode
        self.four_pos_shared = four_pos_shared
        self.abs_pos_fusion_func = abs_pos_fusion_func
        self.lattice_embed = lattice_embed
        self.bigram_embed = bigram_embed
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        # self.relative_position = relative_position
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert four_pos_fusion is not None
        self.four_pos_fusion = four_pos_fusion
        self.learnable_position = learnable_position
        self.add_position = add_position
        self.rel_pos_shared = rel_pos_shared
        self.self_supervised = self_supervised
        self.vocabs = vocabs
        self.attn_ff = attn_ff
        self.pos_norm = pos_norm
        self.ff_activate = ff_activate
        self.rel_pos_init = rel_pos_init
        self.embed_dropout_pos = embed_dropout_pos

        # if self.relative_position:
        #     print('Relative position encoding is not supported yet!')
        #     exit(1208)

        # if self.add_position:
        #     print('Only the concat mode of position encoding is supported for now')
        #     exit(1208)

        if self.use_rel_pos and max_seq_len < 0:
            print_info('max_seq_len should be set when using relative position encoding')
            exit(1208)

        self.max_seq_len = max_seq_len

        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj
        self.r_proj = r_proj

        self.pe = None

        if self.use_abs_pos:
            self.abs_pos_encode = Absolute_SE_Position_Embedding(
                self.abs_pos_fusion_func,
                self.hidden_size,
                learnable=self.learnable_position,
                mode=self.mode,
                pos_norm=self.pos_norm)

        if self.use_rel_pos:
            pe = get_embedding(max_seq_len,
                               hidden_size,
                               rel_pos_init=self.rel_pos_init)
            pe_sum = pe.sum(dim=-1, keepdim=True)
            if self.pos_norm:
                with torch.no_grad():
                    pe = pe / pe_sum
            self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
            if self.four_pos_shared:
                self.pe_ss = self.pe
                self.pe_se = self.pe
                self.pe_es = self.pe
                self.pe_ee = self.pe
            else:
                self.pe_ss = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_se = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_es = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_ee = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
        else:
            self.pe = None
            self.pe_ss = None
            self.pe_se = None
            self.pe_es = None
            self.pe_ee = None

        self.layer_preprocess_sequence = layer_preprocess_sequence
        self.layer_postprocess_sequence = layer_postprocess_sequence
        if ff_size == -1:
            ff_size = self.hidden_size
        self.ff_size = ff_size
        self.scaled = scaled
        if dvc is None:
            dvc = 'cpu'
        self.dvc = torch.device(dvc)
        if dropout is None:
            self.dropout = collections.defaultdict(int)
        else:
            self.dropout = dropout
        self.use_bigram = use_bigram

        if self.use_bigram:
            self.bigram_size = self.bigram_embed.embedding.weight.size(1)
            self.char_input_size = self.lattice_embed.embedding.weight.size(
                1) + self.bigram_embed.embedding.weight.size(1)
        else:
            self.char_input_size = self.lattice_embed.embedding.weight.size(1)

        self.lex_input_size = self.lattice_embed.embedding.weight.size(1)

        if use_pytorch_dropout:
            self.embed_dropout = nn.Dropout(self.dropout['embed'])
            self.gaz_dropout = nn.Dropout(self.dropout['gaz'])
            self.output_dropout = nn.Dropout(self.dropout['output'])
        else:
            self.embed_dropout = MyDropout(self.dropout['embed'])
            self.gaz_dropout = MyDropout(self.dropout['gaz'])
            self.output_dropout = MyDropout(self.dropout['output'])

        self.char_proj = nn.Linear(self.char_input_size, self.hidden_size)
        self.lex_proj = nn.Linear(self.lex_input_size, self.hidden_size)

        self.encoder = Transformer_Encoder(
            self.hidden_size,
            self.num_heads,
            self.num_layers,
            relative_position=self.use_rel_pos,
            learnable_position=self.learnable_position,
            add_position=self.add_position,
            layer_preprocess_sequence=self.layer_preprocess_sequence,
            layer_postprocess_sequence=self.layer_postprocess_sequence,
            dropout=self.dropout,
            scaled=self.scaled,
            ff_size=self.ff_size,
            mode=self.mode,
            dvc=self.dvc,
            max_seq_len=self.max_seq_len,
            pe=self.pe,
            pe_ss=self.pe_ss,  # these are embeddings
            pe_se=self.pe_se,
            pe_es=self.pe_es,
            pe_ee=self.pe_ee,
            k_proj=self.k_proj,
            q_proj=self.q_proj,
            v_proj=self.v_proj,
            r_proj=self.r_proj,
            attn_ff=self.attn_ff,
            ff_activate=self.ff_activate,
            lattice=True,
            four_pos_fusion=self.four_pos_fusion,
            four_pos_fusion_shared=self.four_pos_fusion_shared,
            use_pytorch_dropout=self.use_pytorch_dropout)

        self.output = nn.Linear(self.hidden_size, self.label_size)
        if self.self_supervised:
            self.output_self_supervised = nn.Linear(self.hidden_size,
                                                    len(vocabs['char']))
            print('self.output_self_supervised:{}'.format(
                self.output_self_supervised.weight.size()))
        self.crf = get_crf_zero_init(self.label_size)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
        self.batch_num = 0
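get_embedding is not shown in this listing. To illustrate the rel_pos_init parameter described in the docstring, here is a sketch of a standard sinusoidal construction that matches that description (indices 0..2*max_len when rel_pos_init is 0, indices -max_len..max_len when it is 1). Treat it as an illustration, not necessarily the repository's exact implementation, and note it assumes an even embedding_dim.

import math
import torch

def get_embedding(max_seq_len, embedding_dim, rel_pos_init=0):
    """Sinusoidal embeddings for relative positions in [-max_seq_len, max_seq_len] (illustrative sketch)."""
    num_embeddings = 2 * max_seq_len + 1
    half_dim = embedding_dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
    if rel_pos_init == 0:
        positions = torch.arange(num_embeddings, dtype=torch.float)                 # 0 .. 2*max_len
    else:
        positions = torch.arange(-max_seq_len, max_seq_len + 1, dtype=torch.float)  # -max_len .. max_len
    angles = positions.unsqueeze(1) * freqs.unsqueeze(0)               # (2*max_len+1, half_dim)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)    # (2*max_len+1, embedding_dim)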
Example #4
    def __init__(self,
                 vocab: Vocabulary,
                 model_dir_or_name: str = 'en',
                 embedding_dim=-1,
                 requires_grad: bool = True,
                 init_method=None,
                 lower=False,
                 dropout=0,
                 word_dropout=0,
                 normalize=False,
                 min_freq=1,
                 **kwargs):
        """

        :param vocab: Vocabulary. 若该项为None则会读取所有的embedding。
        :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个
            以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
            如果输入为None则使用embedding_dim的维度随机初始化一个embedding。
        :param int embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。
        :param bool requires_grad: 是否需要gradient. 默认为True
        :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法, 传入的方法应该接受一个tensor,并
            inplace地修改其值。
        :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独
            为大写的词语开辟一个vector表示,则将lower设置为False。
        :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
        :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
        :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。
        :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。
        :param dict kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中找到的词语使用normalize。
        """
        super(StaticEmbedding, self).__init__(vocab,
                                              word_dropout=word_dropout,
                                              dropout=dropout)
        if embedding_dim > 0:
            model_dir_or_name = None

        # resolve the cache path
        if model_dir_or_name is None:
            assert embedding_dim >= 1, "The dimension of embedding should be at least 1."
            embedding_dim = int(embedding_dim)
            model_path = None
        elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
            model_url = _get_embedding_url('static', model_dir_or_name.lower())
            model_path = cached_path(model_url, name='embedding')
            # check whether the file exists
        elif os.path.isfile(
                os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = os.path.abspath(os.path.expanduser(model_dir_or_name))
        elif os.path.isdir(
                os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = _get_file_name_base_on_postfix(
                os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
        else:
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")

        # shrink the vocab according to min_freq
        truncate_vocab = (vocab.min_freq is None
                          and min_freq > 1) or (vocab.min_freq
                                                and vocab.min_freq < min_freq)
        if truncate_vocab:
            truncated_vocab = deepcopy(vocab)
            truncated_vocab.min_freq = min_freq
            truncated_vocab.word2idx = None
            if lower:  # with lower enabled, frequencies of upper- and lowercase forms must be counted together
                lowered_word_count = defaultdict(int)
                for word, count in truncated_vocab.word_count.items():
                    lowered_word_count[word.lower()] += count
                for word in truncated_vocab.word_count.keys():
                    word_count = truncated_vocab.word_count[word]
                    if lowered_word_count[word.lower(
                    )] >= min_freq and word_count < min_freq:
                        truncated_vocab.add_word_lst(
                            [word] * (min_freq - word_count),
                            no_create_entry=truncated_vocab.
                            _is_word_no_create_entry(word))

            # apply the min_freq filter only to words that appear in the training set
            if kwargs.get('only_train_min_freq',
                          False) and model_dir_or_name is not None:
                for word in truncated_vocab.word_count.keys():
                    if truncated_vocab._is_word_no_create_entry(
                            word
                    ) and truncated_vocab.word_count[word] < min_freq:
                        truncated_vocab.add_word_lst(
                            [word] *
                            (min_freq - truncated_vocab.word_count[word]),
                            no_create_entry=True)
            truncated_vocab.build_vocab()
            truncated_words_to_words = torch.arange(len(vocab)).long()
            for word, index in vocab:
                truncated_words_to_words[index] = truncated_vocab.to_index(
                    word)
            logger.info(
                f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}."
            )
            vocab = truncated_vocab

        self.only_norm_found_vector = kwargs.get('only_norm_found_vector',
                                                 False)
        # load the embedding
        if lower:
            lowered_vocab = Vocabulary(padding=vocab.padding,
                                       unknown=vocab.unknown)
            for word, index in vocab:
                if vocab._is_word_no_create_entry(word):
                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
                else:
                    lowered_vocab.add_word(word.lower())  # add words that need their own entry first
            logger.info(
                f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
                f"unique lowered words.")
            if model_path:
                embedding = self._load_with_vocab(model_path,
                                                  vocab=lowered_vocab,
                                                  init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(vocab),
                                                      embedding_dim,
                                                      init_method)
                self.register_buffer('words_to_words',
                                     torch.arange(len(vocab)).long())
            if lowered_vocab.unknown:
                unknown_idx = lowered_vocab.unknown_idx
            else:
                unknown_idx = embedding.size(0) - 1  # otherwise the last row serves as unknown
                self.register_buffer('words_to_words',
                                     torch.arange(len(vocab)).long())
            words_to_words = torch.full((len(vocab), ),
                                        fill_value=unknown_idx).long()
            for word, index in vocab:
                if word not in lowered_vocab:
                    word = word.lower()
                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(
                            word):
                        continue  # no entry needed; it already defaults to unknown
                words_to_words[index] = self.words_to_words[
                    lowered_vocab.to_index(word)]
            self.register_buffer('words_to_words', words_to_words)
            self._word_unk_index = lowered_vocab.unknown_idx  # replace the unknown index
        else:
            if model_path:
                embedding = self._load_with_vocab(model_path,
                                                  vocab=vocab,
                                                  init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(vocab),
                                                      embedding_dim,
                                                      init_method)
                self.register_buffer('words_to_words',
                                     torch.arange(len(vocab)).long())
        if not self.only_norm_found_vector and normalize:
            embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)

        if truncate_vocab:
            for i in range(len(truncated_words_to_words)):
                index_in_truncated_vocab = truncated_words_to_words[i]
                truncated_words_to_words[i] = self.words_to_words[
                    index_in_truncated_vocab]
            del self.words_to_words
            self.register_buffer('words_to_words', truncated_words_to_words)
        self.embedding = nn.Embedding(num_embeddings=embedding.shape[0],
                                      embedding_dim=embedding.shape[1],
                                      padding_idx=vocab.padding_idx,
                                      max_norm=None,
                                      norm_type=2,
                                      scale_grad_by_freq=False,
                                      sparse=False,
                                      _weight=embedding)
        self._embed_size = self.embedding.weight.size(1)
        self.requires_grad = requires_grad
        self.dropout = MyDropout(dropout)
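A usage illustration for the constructor above, assuming a fastNLP-style Vocabulary; the corpus and hyperparameters below are made up.

# Hypothetical usage of the StaticEmbedding defined above.
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("the cat sat on the mat".split())
vocab.build_vocab()

embed = StaticEmbedding(vocab,
                        model_dir_or_name=None,  # no pretrained file, so random init
                        embedding_dim=50,
                        lower=True,              # match against a lowercased vocabulary
                        min_freq=1,
                        dropout=0.1)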
Example #5
    def __init__(self, lattice_embed, bigram_embed, hidden_size, label_size,
                 num_heads, num_layers,
                 use_abs_pos, use_rel_pos, learnable_position, add_position,
                 layer_preprocess_sequence, layer_postprocess_sequence,
                 ff_size=-1, scaled=True, dropout=None, use_bigram=True, mode=collections.defaultdict(bool),
                 dvc=None, vocabs=None,
                 rel_pos_shared=True, max_seq_len=-1, k_proj=True, q_proj=True, v_proj=True, r_proj=True,
                 self_supervised=False, attn_ff=True, pos_norm=False, ff_activate='relu', rel_pos_init=0,
                 abs_pos_fusion_func='concat', embed_dropout_pos='0',
                 four_pos_shared=True, four_pos_fusion=None, four_pos_fusion_shared=True, bert_embedding=None,
                 new_tag_scheme=False, span_loss_alpha=1.0, ple_channel_num=1, use_ple_lstm=False):
        '''
        :param rel_pos_init: if 0, the relative position embedding matrix covering -max_len..max_len is initialized
            with indices 0..2*max_len; if 1, it is initialized directly with indices -max_len..max_len.

        :param embed_dropout_pos: if '0', dropout is applied right after the embedding lookup; if '1', it is applied
            after the embedding has been projected to hidden size; if '2', it is applied after the absolute position
            encoding has been added.
        '''
        super().__init__()

        self.use_bert = False
        if bert_embedding is not None:
            self.use_bert = True
            self.bert_embedding = bert_embedding

        self.four_pos_fusion_shared = four_pos_fusion_shared
        self.mode = mode
        self.four_pos_shared = four_pos_shared
        self.abs_pos_fusion_func = abs_pos_fusion_func
        self.lattice_embed = lattice_embed
        self.bigram_embed = bigram_embed
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        # self.relative_position = relative_position
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert four_pos_fusion is not None
        self.four_pos_fusion = four_pos_fusion
        self.learnable_position = learnable_position
        self.add_position = add_position
        self.rel_pos_shared = rel_pos_shared
        self.self_supervised = self_supervised
        self.vocabs = vocabs
        self.attn_ff = attn_ff
        self.pos_norm = pos_norm
        self.ff_activate = ff_activate
        self.rel_pos_init = rel_pos_init
        self.embed_dropout_pos = embed_dropout_pos

        # if self.relative_position:
        #     print('Relative position encoding is not supported yet!')
        #     exit(1208)

        # if self.add_position:
        #     print('Only the concat mode of position encoding is supported for now')
        #     exit(1208)

        if self.use_rel_pos and max_seq_len < 0:
            print_info('max_seq_len should be set when using relative position encoding')
            exit(1208)

        self.max_seq_len = max_seq_len

        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj
        self.r_proj = r_proj

        self.pe = None

        if self.use_abs_pos:
            self.abs_pos_encode = Absolute_SE_Position_Embedding(self.abs_pos_fusion_func,
                                                                 self.hidden_size, learnable=self.learnable_position,
                                                                 mode=self.mode,
                                                                 pos_norm=self.pos_norm)

        if self.use_rel_pos:
            pe = get_embedding(max_seq_len, hidden_size, rel_pos_init=self.rel_pos_init)
            pe_sum = pe.sum(dim=-1, keepdim=True)
            if self.pos_norm:
                with torch.no_grad():
                    pe = pe / pe_sum
            self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
            if self.four_pos_shared:
                self.pe_ss = self.pe
                self.pe_se = self.pe
                self.pe_es = self.pe
                self.pe_ee = self.pe
            else:
                self.pe_ss = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_se = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_es = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_ee = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
        else:
            self.pe = None
            self.pe_ss = None
            self.pe_se = None
            self.pe_es = None
            self.pe_ee = None

        # if self.add_position:
        #     print('Adding position encoding via concat is not supported yet')
        #     exit(1208)

        self.layer_preprocess_sequence = layer_preprocess_sequence
        self.layer_postprocess_sequence = layer_postprocess_sequence
        if ff_size == -1:
            ff_size = self.hidden_size
        self.ff_size = ff_size
        self.scaled = scaled
        if dvc is None:
            dvc = 'cpu'
        self.dvc = torch.device(dvc)
        if dropout is None:
            self.dropout = collections.defaultdict(int)
        else:
            self.dropout = dropout
        self.use_bigram = use_bigram

        if self.use_bigram:
            self.bigram_size = self.bigram_embed.embedding.weight.size(1)
            self.char_input_size = self.lattice_embed.embedding.weight.size(
                1) + self.bigram_embed.embedding.weight.size(1)
        else:
            self.char_input_size = self.lattice_embed.embedding.weight.size(1)

        if self.use_bert:
            self.char_input_size += self.bert_embedding._embed_size

        self.lex_input_size = self.lattice_embed.embedding.weight.size(1)

        self.embed_dropout = MyDropout(self.dropout['embed'])
        self.gaz_dropout = MyDropout(self.dropout['gaz'])

        self.char_proj = nn.Linear(self.char_input_size, self.hidden_size)
        self.lex_proj = nn.Linear(self.lex_input_size, self.hidden_size)

        self.encoder = Transformer_Encoder(self.hidden_size, self.num_heads, self.num_layers,
                                           relative_position=self.use_rel_pos,
                                           learnable_position=self.learnable_position,
                                           add_position=self.add_position,
                                           layer_preprocess_sequence=self.layer_preprocess_sequence,
                                           layer_postprocess_sequence=self.layer_postprocess_sequence,
                                           dropout=self.dropout,
                                           scaled=self.scaled,
                                           ff_size=self.ff_size,
                                           mode=self.mode,
                                           dvc=self.dvc,
                                           max_seq_len=self.max_seq_len,
                                           pe=self.pe,
                                           pe_ss=self.pe_ss,
                                           pe_se=self.pe_se,
                                           pe_es=self.pe_es,
                                           pe_ee=self.pe_ee,
                                           k_proj=self.k_proj,
                                           q_proj=self.q_proj,
                                           v_proj=self.v_proj,
                                           r_proj=self.r_proj,
                                           attn_ff=self.attn_ff,
                                           ff_activate=self.ff_activate,
                                           lattice=True,
                                           four_pos_fusion=self.four_pos_fusion,
                                           four_pos_fusion_shared=self.four_pos_fusion_shared)

        self.output_dropout = MyDropout(self.dropout['output'])

        self.output = nn.Linear(self.hidden_size, self.label_size)
        if self.self_supervised:
            self.output_self_supervised = nn.Linear(self.hidden_size, len(vocabs['char']))
            print('self.output_self_supervised:{}'.format(self.output_self_supervised.weight.size()))

        self.span_label_size = len(vocabs['span_label'])
        self.attr_label_size = len(vocabs['attr_label'])
        self.new_tag_scheme = new_tag_scheme

        if self.new_tag_scheme:
            self.crf = get_crf_zero_init(self.span_label_size, include_start_end_trans=True)
            weight = torch.FloatTensor([1.0 if vocabs['attr_label'].to_word(i) != ATTR_NULL_TAG else 0.1
                                        for i in range(self.attr_label_size)])
            self.attr_criterion = nn.CrossEntropyLoss(reduction='none')
            self.ple = PLE(hidden_size=hidden_size, span_label_size=self.span_label_size,
                           attr_label_size=self.attr_label_size, dropout_rate=0.1, experts_layers=2, experts_num=1,
                           ple_dropout=0.1, use_ple_lstm=use_ple_lstm)
            self.span_loss_alpha = span_loss_alpha
            self.ple_channel_num = ple_channel_num
            self.encoder_list = clones(self.encoder, self.ple_channel_num)
        else:
            self.crf = get_crf_zero_init(self.label_size, include_start_end_trans=True)
            self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
        self.steps = 0
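clones (used to build self.encoder_list) is a helper that does not appear in this listing. It is conventionally defined as n independent deep copies collected in an nn.ModuleList; the sketch below follows that convention and is not necessarily the repository's exact code.

import copy
import torch.nn as nn

def clones(module, n):
    """Return n independent deep copies of a module in an nn.ModuleList (illustrative sketch)."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])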