Example #1
    def __init__(self, bert_embedding, label_size, vocabs, after_bert):
        super().__init__()
        self.after_bert = after_bert
        self.bert_embedding = bert_embedding
        self.label_size = label_size
        self.vocabs = vocabs
        self.hidden_size = bert_embedding._embed_size
        # Project the BERT hidden states to the tag space and decode with a
        # zero-initialized CRF.
        self.output = nn.Linear(self.hidden_size, self.label_size)
        self.crf = get_crf_zero_init(self.label_size)
        if self.after_bert == 'lstm':
            # Optional bidirectional LSTM on top of BERT; each direction uses half the
            # embedding size, so the concatenated output keeps the original size.
            self.lstm = LSTM(bert_embedding._embed_size, bert_embedding._embed_size // 2,
                             bidirectional=True)
        self.dropout = MyDropout(0.5)
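
A minimal shape sketch of the head defined above. BertEmbedding, MyDropout, get_crf_zero_init and LSTM are project-specific, so plain torch stand-ins are used here and the sizes (768-dimensional BERT output, 9 labels) are illustrative only.

import torch
import torch.nn as nn

embed_size, label_size, batch, seq_len = 768, 9, 2, 20           # illustrative sizes
hidden = torch.randn(batch, seq_len, embed_size)                 # stands in for bert_embedding(chars)
lstm = nn.LSTM(embed_size, embed_size // 2, bidirectional=True, batch_first=True)
hidden, _ = lstm(hidden)                                         # each direction is embed_size//2, so the
                                                                 # concatenated output keeps embed_size
logits = nn.Linear(embed_size, label_size)(hidden)               # per-token emission scores for the CRF
print(logits.shape)                                              # torch.Size([2, 20, 9])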
    def __init__(self,
                 lattice_embed,
                 bigram_embed,
                 hidden_size,
                 label_size,
                 num_heads,
                 num_layers,
                 use_abs_pos,
                 use_rel_pos,
                 learnable_position,
                 add_position,
                 layer_preprocess_sequence,
                 layer_postprocess_sequence,
                 ff_size=-1,
                 scaled=True,
                 dropout=None,
                 use_bigram=True,
                 mode=collections.defaultdict(bool),
                 dvc=None,
                 vocabs=None,
                 rel_pos_shared=True,
                 max_seq_len=-1,
                 k_proj=True,
                 q_proj=True,
                 v_proj=True,
                 r_proj=True,
                 self_supervised=False,
                 attn_ff=True,
                 pos_norm=False,
                 ff_activate='relu',
                 rel_pos_init=0,
                 abs_pos_fusion_func='concat',
                 embed_dropout_pos='0',
                 four_pos_shared=True,
                 four_pos_fusion=None,
                 four_pos_fusion_shared=True,
                 bert_embedding=None,
                 use_pytorch_dropout=False):
        '''
        :param rel_pos_init: if 0, the relative position embedding matrix for positions -max_len..max_len
            is initialized over the index range 0..2*max_len; if 1, it is initialized over -max_len..max_len.

        :param embed_dropout_pos: if '0', dropout is applied directly after the embedding lookup; if '1',
            after the embedding has been projected to the hidden size; if '2', after the absolute position
            encoding has been added.
        '''
        super().__init__()
        self.use_pytorch_dropout = use_pytorch_dropout
        self.four_pos_fusion_shared = four_pos_fusion_shared
        self.mode = mode
        self.four_pos_shared = four_pos_shared
        self.abs_pos_fusion_func = abs_pos_fusion_func
        self.lattice_embed = lattice_embed
        self.bigram_embed = bigram_embed
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        # self.relative_position = relative_position
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert four_pos_fusion is not None
        self.four_pos_fusion = four_pos_fusion
        self.learnable_position = learnable_position
        self.add_position = add_position
        self.rel_pos_shared = rel_pos_shared
        self.self_supervised = self_supervised
        self.vocabs = vocabs
        self.attn_ff = attn_ff
        self.pos_norm = pos_norm
        self.ff_activate = ff_activate
        self.rel_pos_init = rel_pos_init
        self.embed_dropout_pos = embed_dropout_pos

        # if self.relative_position:
        #     print('relative position encoding is not supported yet!')
        #     exit(1208)

        # if self.add_position:
        #     print('only the concat mode of position encoding is supported for now')
        #     exit(1208)

        if self.use_rel_pos and max_seq_len < 0:
            print_info('max_seq_len must be set when relative position encoding is used')
            exit(1208)

        self.max_seq_len = max_seq_len

        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj
        self.r_proj = r_proj

        self.pe = None

        if self.use_abs_pos:
            self.abs_pos_encode = Absolute_SE_Position_Embedding(
                self.abs_pos_fusion_func,
                self.hidden_size,
                learnable=self.learnable_position,
                mode=self.mode,
                pos_norm=self.pos_norm)

        if self.use_rel_pos:
            pe = get_embedding(max_seq_len,
                               hidden_size,
                               rel_pos_init=self.rel_pos_init)
            pe_sum = pe.sum(dim=-1, keepdim=True)
            if self.pos_norm:
                with torch.no_grad():
                    pe = pe / pe_sum
            self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
            if self.four_pos_shared:
                self.pe_ss = self.pe
                self.pe_se = self.pe
                self.pe_es = self.pe
                self.pe_ee = self.pe
            else:
                self.pe_ss = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_se = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_es = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_ee = nn.Parameter(
                    copy.deepcopy(pe), requires_grad=self.learnable_position)
        else:
            self.pe = None
            self.pe_ss = None
            self.pe_se = None
            self.pe_es = None
            self.pe_ee = None

        self.layer_preprocess_sequence = layer_preprocess_sequence
        self.layer_postprocess_sequence = layer_postprocess_sequence
        if ff_size == -1:
            ff_size = self.hidden_size
        self.ff_size = ff_size
        self.scaled = scaled
        if dvc is None:
            dvc = 'cpu'
        self.dvc = torch.device(dvc)
        if dropout is None:
            self.dropout = collections.defaultdict(int)
        else:
            self.dropout = dropout
        self.use_bigram = use_bigram

        if self.use_bigram:
            self.bigram_size = self.bigram_embed.embedding.weight.size(1)
            self.char_input_size = self.lattice_embed.embedding.weight.size(
                1) + self.bigram_embed.embedding.weight.size(1)
        else:
            self.char_input_size = self.lattice_embed.embedding.weight.size(1)

        self.lex_input_size = self.lattice_embed.embedding.weight.size(1)

        if use_pytorch_dropout:
            self.embed_dropout = nn.Dropout(self.dropout['embed'])
            self.gaz_dropout = nn.Dropout(self.dropout['gaz'])
            self.output_dropout = nn.Dropout(self.dropout['output'])
        else:
            self.embed_dropout = MyDropout(self.dropout['embed'])
            self.gaz_dropout = MyDropout(self.dropout['gaz'])
            self.output_dropout = MyDropout(self.dropout['output'])

        self.char_proj = nn.Linear(self.char_input_size, self.hidden_size)
        self.lex_proj = nn.Linear(self.lex_input_size, self.hidden_size)

        self.encoder = Transformer_Encoder(
            self.hidden_size,
            self.num_heads,
            self.num_layers,
            relative_position=self.use_rel_pos,
            learnable_position=self.learnable_position,
            add_position=self.add_position,
            layer_preprocess_sequence=self.layer_preprocess_sequence,
            layer_postprocess_sequence=self.layer_postprocess_sequence,
            dropout=self.dropout,
            scaled=self.scaled,
            ff_size=self.ff_size,
            mode=self.mode,
            dvc=self.dvc,
            max_seq_len=self.max_seq_len,
            pe=self.pe,
            pe_ss=self.pe_ss,  # the (shared or per-type) relative position embeddings
            pe_se=self.pe_se,
            pe_es=self.pe_es,
            pe_ee=self.pe_ee,
            k_proj=self.k_proj,
            q_proj=self.q_proj,
            v_proj=self.v_proj,
            r_proj=self.r_proj,
            attn_ff=self.attn_ff,
            ff_activate=self.ff_activate,
            lattice=True,
            four_pos_fusion=self.four_pos_fusion,
            four_pos_fusion_shared=self.four_pos_fusion_shared,
            use_pytorch_dropout=self.use_pytorch_dropout)

        self.output = nn.Linear(self.hidden_size, self.label_size)
        if self.self_supervised:
            self.output_self_supervised = nn.Linear(self.hidden_size,
                                                    len(vocabs['char']))
            print('self.output_self_supervised:{}'.format(
                self.output_self_supervised.weight.size()))
        self.crf = get_crf_zero_init(self.label_size)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
        self.batch_num = 0
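
Both lattice encoders in this section build one relative-position table and either alias it four times (four_pos_shared) or clone it once per distance type (ss/se/es/ee). A small sketch of that branch follows; build_pe is a hypothetical sinusoidal stand-in for the project's get_embedding, and the sizes are illustrative.

import copy
import torch
import torch.nn as nn

def build_pe(max_seq_len, hidden_size):
    # Hypothetical stand-in for get_embedding: a sinusoidal table over the
    # relative distances -max_seq_len..max_seq_len (2*max_seq_len + 1 rows).
    pos = torch.arange(-max_seq_len, max_seq_len + 1, dtype=torch.float).unsqueeze(1)
    dim = torch.arange(0, hidden_size, 2, dtype=torch.float)
    angle = pos / torch.pow(10000.0, dim / hidden_size)
    pe = torch.zeros(2 * max_seq_len + 1, hidden_size)
    pe[:, 0::2] = torch.sin(angle)
    pe[:, 1::2] = torch.cos(angle)
    return pe

pe = nn.Parameter(build_pe(128, 160), requires_grad=False)       # learnable_position=False
four_pos_shared = True
if four_pos_shared:
    pe_ss = pe_se = pe_es = pe_ee = pe                           # one table, four aliases
else:
    pe_ss, pe_se, pe_es, pe_ee = (nn.Parameter(copy.deepcopy(pe.data)) for _ in range(4))
print(pe_ss is pe_ee)                                            # True when shared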
    def __init__(self, lattice_embed, bigram_embed, hidden_size, label_size,
                 num_heads, num_layers,
                 use_abs_pos, use_rel_pos, learnable_position, add_position,
                 layer_preprocess_sequence, layer_postprocess_sequence,
                 ff_size=-1, scaled=True, dropout=None, use_bigram=True, mode=collections.defaultdict(bool),
                 dvc=None, vocabs=None,
                 rel_pos_shared=True, max_seq_len=-1, k_proj=True, q_proj=True, v_proj=True, r_proj=True,
                 self_supervised=False, attn_ff=True, pos_norm=False, ff_activate='relu', rel_pos_init=0,
                 abs_pos_fusion_func='concat', embed_dropout_pos='0',
                 four_pos_shared=True, four_pos_fusion=None, four_pos_fusion_shared=True, bert_embedding=None,
                 new_tag_scheme=False, span_loss_alpha=1.0, ple_channel_num=1, use_ple_lstm=False):
        '''
        :param rel_pos_init: if 0, the relative position embedding matrix for positions -max_len..max_len
            is initialized over the index range 0..2*max_len; if 1, it is initialized over -max_len..max_len.

        :param embed_dropout_pos: if '0', dropout is applied directly after the embedding lookup; if '1',
            after the embedding has been projected to the hidden size; if '2', after the absolute position
            encoding has been added.
        '''
        super().__init__()

        self.use_bert = False
        if bert_embedding is not None:
            self.use_bert = True
            self.bert_embedding = bert_embedding

        self.four_pos_fusion_shared = four_pos_fusion_shared
        self.mode = mode
        self.four_pos_shared = four_pos_shared
        self.abs_pos_fusion_func = abs_pos_fusion_func
        self.lattice_embed = lattice_embed
        self.bigram_embed = bigram_embed
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        # self.relative_position = relative_position
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert four_pos_fusion is not None
        self.four_pos_fusion = four_pos_fusion
        self.learnable_position = learnable_position
        self.add_position = add_position
        self.rel_pos_shared = rel_pos_shared
        self.self_supervised = self_supervised
        self.vocabs = vocabs
        self.attn_ff = attn_ff
        self.pos_norm = pos_norm
        self.ff_activate = ff_activate
        self.rel_pos_init = rel_pos_init
        self.embed_dropout_pos = embed_dropout_pos

        # if self.relative_position:
        #     print('relative position encoding is not supported yet!')
        #     exit(1208)

        # if self.add_position:
        #     print('only the concat mode of position encoding is supported for now')
        #     exit(1208)

        if self.use_rel_pos and max_seq_len < 0:
            print_info('max_seq_len must be set when relative position encoding is used')
            exit(1208)

        self.max_seq_len = max_seq_len

        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj
        self.r_proj = r_proj

        self.pe = None

        if self.use_abs_pos:
            self.abs_pos_encode = Absolute_SE_Position_Embedding(self.abs_pos_fusion_func,
                                                                 self.hidden_size, learnable=self.learnable_position,
                                                                 mode=self.mode,
                                                                 pos_norm=self.pos_norm)

        if self.use_rel_pos:
            pe = get_embedding(max_seq_len, hidden_size, rel_pos_init=self.rel_pos_init)
            pe_sum = pe.sum(dim=-1, keepdim=True)
            if self.pos_norm:
                with torch.no_grad():
                    pe = pe / pe_sum
            self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
            if self.four_pos_shared:
                self.pe_ss = self.pe
                self.pe_se = self.pe
                self.pe_es = self.pe
                self.pe_ee = self.pe
            else:
                self.pe_ss = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_se = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_es = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
                self.pe_ee = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
        else:
            self.pe = None
            self.pe_ss = None
            self.pe_se = None
            self.pe_es = None
            self.pe_ee = None

        # if self.add_position:
        #     print('adding the position encoding via concat is not supported yet')
        #     exit(1208)

        self.layer_preprocess_sequence = layer_preprocess_sequence
        self.layer_postprocess_sequence = layer_postprocess_sequence
        if ff_size == -1:
            ff_size = self.hidden_size
        self.ff_size = ff_size
        self.scaled = scaled
        if dvc is None:
            dvc = 'cpu'
        self.dvc = torch.device(dvc)
        if dropout is None:
            self.dropout = collections.defaultdict(int)
        else:
            self.dropout = dropout
        self.use_bigram = use_bigram

        if self.use_bigram:
            self.bigram_size = self.bigram_embed.embedding.weight.size(1)
            self.char_input_size = self.lattice_embed.embedding.weight.size(
                1) + self.bigram_embed.embedding.weight.size(1)
        else:
            self.char_input_size = self.lattice_embed.embedding.weight.size(1)

        if self.use_bert:
            self.char_input_size += self.bert_embedding._embed_size

        self.lex_input_size = self.lattice_embed.embedding.weight.size(1)

        self.embed_dropout = MyDropout(self.dropout['embed'])
        self.gaz_dropout = MyDropout(self.dropout['gaz'])

        self.char_proj = nn.Linear(self.char_input_size, self.hidden_size)
        self.lex_proj = nn.Linear(self.lex_input_size, self.hidden_size)

        self.encoder = Transformer_Encoder(self.hidden_size, self.num_heads, self.num_layers,
                                           relative_position=self.use_rel_pos,
                                           learnable_position=self.learnable_position,
                                           add_position=self.add_position,
                                           layer_preprocess_sequence=self.layer_preprocess_sequence,
                                           layer_postprocess_sequence=self.layer_postprocess_sequence,
                                           dropout=self.dropout,
                                           scaled=self.scaled,
                                           ff_size=self.ff_size,
                                           mode=self.mode,
                                           dvc=self.dvc,
                                           max_seq_len=self.max_seq_len,
                                           pe=self.pe,
                                           pe_ss=self.pe_ss,
                                           pe_se=self.pe_se,
                                           pe_es=self.pe_es,
                                           pe_ee=self.pe_ee,
                                           k_proj=self.k_proj,
                                           q_proj=self.q_proj,
                                           v_proj=self.v_proj,
                                           r_proj=self.r_proj,
                                           attn_ff=self.attn_ff,
                                           ff_activate=self.ff_activate,
                                           lattice=True,
                                           four_pos_fusion=self.four_pos_fusion,
                                           four_pos_fusion_shared=self.four_pos_fusion_shared)

        self.output_dropout = MyDropout(self.dropout['output'])

        self.output = nn.Linear(self.hidden_size, self.label_size)
        if self.self_supervised:
            self.output_self_supervised = nn.Linear(self.hidden_size, len(vocabs['char']))
            print('self.output_self_supervised:{}'.format(self.output_self_supervised.weight.size()))

        self.span_label_size = len(vocabs['span_label'])
        self.attr_label_size = len(vocabs['attr_label'])
        self.new_tag_scheme = new_tag_scheme

        if self.new_tag_scheme:
            self.crf = get_crf_zero_init(self.span_label_size, include_start_end_trans=True)
            weight = torch.FloatTensor([1.0 if vocabs['attr_label'].to_word(i) != ATTR_NULL_TAG else 0.1
                                        for i in range(self.attr_label_size)])
            self.attr_criterion = nn.CrossEntropyLoss(reduction='none')
            self.ple = PLE(hidden_size=hidden_size, span_label_size=self.span_label_size,
                           attr_label_size=self.attr_label_size, dropout_rate=0.1, experts_layers=2, experts_num=1,
                           ple_dropout=0.1, use_ple_lstm=use_ple_lstm)
            self.span_loss_alpha = span_loss_alpha
            self.ple_channel_num = ple_channel_num
            self.encoder_list = clones(self.encoder, self.ple_channel_num)
        else:
            self.crf = get_crf_zero_init(self.label_size, include_start_end_trans=True)
            self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
        self.steps = 0
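
In the new_tag_scheme branch, attribute classification builds a per-class weight vector that down-weights the null attribute tag (ATTR_NULL_TAG) to 0.1. The code above keeps the weights and an unweighted reduction='none' criterion separately; the sketch below, with a hypothetical four-tag attribute vocabulary, simply passes the weights straight into CrossEntropyLoss to show the effect.

import torch
import torch.nn as nn

attr_vocab = ['null', 'name', 'org', 'loc']                      # hypothetical attribute labels
ATTR_NULL_TAG = 'null'
weight = torch.FloatTensor([1.0 if tag != ATTR_NULL_TAG else 0.1 for tag in attr_vocab])

attr_criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
logits = torch.randn(5, len(attr_vocab))                         # 5 tokens, 4 attribute classes
gold = torch.tensor([0, 1, 2, 0, 3])
per_token_loss = attr_criterion(logits, gold)                    # shape (5,): null-tag tokens contribute
                                                                 # a tenth of the weight of real attributes
print(per_token_loss.shape)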