def __init__(
        self,
        char_embed,
        bigram_embed,
        hidden_size,
        label_size,
        num_heads,
        num_layers,
        use_abs_pos,
        use_rel_pos,
        learnable_position,
        add_position,
        layer_preprocess_sequence,
        layer_postprocess_sequence,
        ff_size=-1,
        scaled=True,
        dropout=None,
        use_bigram=True,
        mode=collections.defaultdict(bool),
        dvc=None,
        vocabs=None,
        rel_pos_shared=True,
        max_seq_len=-1,
        k_proj=True,
        q_proj=True,
        v_proj=True,
        r_proj=True,
        self_supervised=False,
        attn_ff=True,
        pos_norm=False,
        ff_activate='relu',
        rel_pos_init=0,
        abs_pos_fusion_func='concat',
        embed_dropout_pos='0',
    ):
        '''
        :param rel_pos_init: if 0, the relative position embedding matrix covering
        positions -max_len..max_len is initialized with indices 0..2*max_len;
        if 1, it is initialized with indices -max_len..max_len directly.
        '''
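        # Illustration (a sketch of what the docstring describes; the actual index
        # handling lives inside get_embedding, which is not shown in this snippet):
        # with max_seq_len = 4,
        #   rel_pos_init = 0 -> rows indexed 0, 1, ..., 8   (i.e. 0..2*max_len)
        #   rel_pos_init = 1 -> rows indexed -4, -3, ..., 4 (i.e. -max_len..max_len)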
        super().__init__()
        self.abs_pos_fusion_func = abs_pos_fusion_func
        self.embed_dropout_pos = embed_dropout_pos
        self.mode = mode
        self.char_embed = char_embed
        self.bigram_embed = bigram_embed
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        # self.relative_position = relative_position
        self.learnable_position = learnable_position
        self.add_position = add_position
        self.rel_pos_shared = rel_pos_shared
        self.self_supervised = self_supervised
        self.vocabs = vocabs
        self.attn_ff = attn_ff
        self.pos_norm = pos_norm
        self.ff_activate = ff_activate
        self.rel_pos_init = rel_pos_init

        if self.use_rel_pos and max_seq_len < 0:
            print_info('max_seq_len must be set when relative position encoding is used')
            exit(1208)

        self.max_seq_len = max_seq_len

        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj
        self.r_proj = r_proj

        self.pe = None
        if self.use_abs_pos:
            self.pos_encode = Absolute_Position_Embedding(
                self.abs_pos_fusion_func,
                self.hidden_size,
                learnable=self.learnable_position,
                mode=self.mode,
                pos_norm=self.pos_norm)

        if self.use_rel_pos:
            pe = get_embedding(max_seq_len,
                               hidden_size,
                               rel_pos_init=self.rel_pos_init)
            pe_sum = pe.sum(dim=-1, keepdim=True)
            if self.pos_norm:
                # normalize each position's encoding so that its entries sum to 1
                with torch.no_grad():
                    pe = pe / pe_sum
            self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
        else:
            self.pe = None

        # if self.relative_position:
        #     print('Relative position encoding is not supported yet!')
        #     exit(1208)

        # if not self.add_position:
        #     print('Adding the position encoding via concat is not supported yet')
        #     exit(1208)

        self.layer_preprocess_sequence = layer_preprocess_sequence
        self.layer_postprocess_sequence = layer_postprocess_sequence
        if ff_size == -1:
            ff_size = self.hidden_size
        self.ff_size = ff_size
        self.scaled = scaled
        if dvc is None:
            dvc = 'cpu'
        self.dvc = torch.device(dvc)
        if dropout is None:
            self.dropout = collections.defaultdict(int)
        else:
            self.dropout = dropout
        self.use_bigram = use_bigram

        if self.use_bigram:
            self.input_size = self.char_embed.embedding.weight.size(
                1) + self.bigram_embed.embedding.weight.size(1)
        else:
            self.input_size = self.char_embed.embedding.weight.size(1)
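        # e.g. with 50-dim char and 50-dim bigram embeddings and use_bigram=True,
        # input_size is 100; w_proj below maps it to hidden_size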

        self.embed_dropout = MyDropout(self.dropout['embed'])
        self.w_proj = nn.Linear(self.input_size, self.hidden_size)
        self.encoder = Transformer_Encoder(
            self.hidden_size,
            self.num_heads,
            self.num_layers,
            relative_position=self.use_rel_pos,
            learnable_position=self.learnable_position,
            add_position=self.add_position,
            layer_preprocess_sequence=self.layer_preprocess_sequence,
            layer_postprocess_sequence=self.layer_postprocess_sequence,
            dropout=self.dropout,
            scaled=self.scaled,
            ff_size=self.ff_size,
            mode=self.mode,
            dvc=self.dvc,
            max_seq_len=self.max_seq_len,
            pe=self.pe,
            k_proj=self.k_proj,
            q_proj=self.q_proj,
            v_proj=self.v_proj,
            r_proj=self.r_proj,
            attn_ff=self.attn_ff,
            ff_activate=self.ff_activate,
        )

        self.output_dropout = MyDropout(self.dropout['output'])

        self.output = nn.Linear(self.hidden_size, self.label_size)
        if self.self_supervised:
            self.output_self_supervised = nn.Linear(self.hidden_size,
                                                    len(vocabs['char']))
            print('self.output_self_supervised:{}'.format(
                self.output_self_supervised.weight.size()))
        self.crf = get_crf_zero_init(self.label_size)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
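
A minimal construction sketch for the example above. The enclosing class name is not
shown in this snippet, so CharTransformerSeqLabel below is a hypothetical stand-in,
and the helpers this __init__ relies on (Transformer_Encoder,
Absolute_Position_Embedding, MyDropout, get_embedding, get_crf_zero_init, print_info)
are assumed to come from the same module. DummyEmbed only provides the
.embedding.weight attribute the constructor actually reads.

import torch.nn as nn

class DummyEmbed(nn.Module):
    # stand-in for the project's char/bigram embedding objects
    def __init__(self, vocab_size, dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)

# Hypothetical instantiation; hyperparameter values are illustrative only.
model = CharTransformerSeqLabel(
    char_embed=DummyEmbed(4000, 50),
    bigram_embed=DummyEmbed(20000, 50),
    hidden_size=128,
    label_size=18,
    num_heads=8,
    num_layers=1,
    use_abs_pos=False,
    use_rel_pos=True,                 # requires max_seq_len >= 0 (see the check above)
    learnable_position=False,
    add_position=False,
    layer_preprocess_sequence='',
    layer_postprocess_sequence='an',  # assumed repo convention for add + layer norm
    max_seq_len=200,
)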
Example #2
    def __init__(self, lattice_embed, bigram_embed, hidden_size, label_size,
                 num_heads, num_layers,
                 use_abs_pos, use_rel_pos, learnable_position, add_position,
                 layer_preprocess_sequence, layer_postprocess_sequence,
                 ff_size=-1, scaled=True, dropout=None,
                 use_bigram=True,  # concatenate bigram embeddings with the character embeddings (see char_input_size below)
                 mode=collections.defaultdict(bool),
                 dvc=None, vocabs=None,
                 rel_pos_shared=True,
                 max_seq_len=-1,
                 k_proj=True,
                 q_proj=True,
                 v_proj=True,
                 r_proj=True,
                 self_supervised=False,
                 attn_ff=True,
                 pos_norm=False,
                 ff_activate='relu',
                 rel_pos_init=0,
                 abs_pos_fusion_func='concat',
                 embed_dropout_pos='0',
                 four_pos_shared=True,
                 four_pos_fusion=None,
                 four_pos_fusion_shared=True,
                 bert_embedding=None):
        '''
        :param rel_pos_init: if 0, the relative position embedding matrix covering
        positions -max_len..max_len is initialized with indices 0..2*max_len;
        if 1, it is initialized with indices -max_len..max_len directly.

        :param embed_dropout_pos: if '0', dropout is applied right after the embedding lookup;
        if '1', after the embeddings have been projected to hidden_size;
        if '2', after the absolute position encoding has been added.
        '''
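        # Intended dropout placement (a sketch of the docstring above; the forward
        # pass itself is not shown in this snippet):
        #   embed_dropout_pos == '0': embedding lookup -> dropout -> char_proj / lex_proj
        #   embed_dropout_pos == '1': char_proj / lex_proj output -> dropout -> position encoding
        #   embed_dropout_pos == '2': after the absolute position encoding -> dropout -> encoder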
        super().__init__()

        self.use_bert = False  # whether BERT embeddings are used is inferred from bert_embedding below
        if bert_embedding is not None:
            self.use_bert = True
            self.bert_embedding = bert_embedding

        self.four_pos_fusion_shared = four_pos_fusion_shared
        self.mode = mode
        self.four_pos_shared = four_pos_shared
        self.abs_pos_fusion_func = abs_pos_fusion_func
        self.lattice_embed = lattice_embed
        self.bigram_embed = bigram_embed
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        # self.relative_position = relative_position
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert four_pos_fusion is not None
        self.four_pos_fusion = four_pos_fusion
        self.learnable_position = learnable_position
        self.add_position = add_position
        self.rel_pos_shared = rel_pos_shared
        self.self_supervised = self_supervised
        self.vocabs = vocabs
        self.attn_ff = attn_ff
        self.pos_norm = pos_norm
        self.ff_activate = ff_activate
        self.rel_pos_init = rel_pos_init
        self.embed_dropout_pos = embed_dropout_pos

        # if self.relative_position:
        #     print('Relative position encoding is not supported yet!')
        #     exit(1208)

        # if self.add_position:
        #     print('Only the concat mode of position encoding is supported for now')
        #     exit(1208)

        if self.use_rel_pos and max_seq_len < 0:
            print_info('max_seq_len must be set when relative position encoding is used')
            exit(1208)

        self.max_seq_len = max_seq_len

        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj
        self.r_proj = r_proj

        self.pe = None

        if self.use_abs_pos:
            self.abs_pos_encode = Absolute_SE_Position_Embedding(self.abs_pos_fusion_func,
                                        self.hidden_size,learnable=self.learnable_position,mode=self.mode,
                                        pos_norm=self.pos_norm)

        if self.use_rel_pos:
            pe = get_embedding(max_seq_len,hidden_size,rel_pos_init=self.rel_pos_init)
            pe_sum = pe.sum(dim=-1,keepdim=True)
            if self.pos_norm:
                with torch.no_grad():
                    pe = pe/pe_sum
            self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
            if self.four_pos_shared:
                self.pe_ss = self.pe
                self.pe_se = self.pe
                self.pe_es = self.pe
                self.pe_ee = self.pe
            else:
                self.pe_ss = nn.Parameter(copy.deepcopy(pe),requires_grad=self.learnable_position)
                self.pe_se = nn.Parameter(copy.deepcopy(pe),requires_grad=self.learnable_position)
                self.pe_es = nn.Parameter(copy.deepcopy(pe),requires_grad=self.learnable_position)
                self.pe_ee = nn.Parameter(copy.deepcopy(pe),requires_grad=self.learnable_position)
        else:
            self.pe = None
            self.pe_ss = None
            self.pe_se = None
            self.pe_es = None
            self.pe_ee = None
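
        # pe_ss / pe_se / pe_es / pe_ee presumably hold the four relative-distance
        # encodings between span boundaries (start-start, start-end, end-start,
        # end-end), shared or learned separately depending on four_pos_shared
        # (assumption based on the parameter names; the actual fusion happens inside
        # Transformer_Encoder via four_pos_fusion).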

        # if self.add_position:
        #     print('Adding the position encoding via concat is not supported yet')
        #     exit(1208)

        self.layer_preprocess_sequence = layer_preprocess_sequence
        self.layer_postprocess_sequence = layer_postprocess_sequence
        if ff_size == -1:
            ff_size = self.hidden_size
        self.ff_size = ff_size
        self.scaled = scaled
        if dvc is None:
            dvc = 'cpu'
        self.dvc = torch.device(dvc)
        if dropout is None:
            self.dropout = collections.defaultdict(int)
        else:
            self.dropout = dropout
        self.use_bigram = use_bigram

        if self.use_bigram:
            self.bigram_size = self.bigram_embed.embedding.weight.size(1)
            self.char_input_size = self.lattice_embed.embedding.weight.size(1)+self.bigram_embed.embedding.weight.size(1)
        else:
            self.char_input_size = self.lattice_embed.embedding.weight.size(1)

        if self.use_bert:
            self.char_input_size+=self.bert_embedding._embed_size

        self.lex_input_size = self.lattice_embed.embedding.weight.size(1)
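        # e.g. with a 50-dim lattice embedding and 50-dim bigrams and use_bigram=True,
        # char_input_size is 100 (plus bert_embedding._embed_size when use_bert) and
        # lex_input_size is 50; char_proj / lex_proj below map both to hidden_size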
        self.embed_dropout = MyDropout(self.dropout['embed'])
        self.gaz_dropout = MyDropout(self.dropout['gaz'])
        self.char_proj = nn.Linear(self.char_input_size,self.hidden_size)
        self.lex_proj = nn.Linear(self.lex_input_size,self.hidden_size)
        self.encoder = Transformer_Encoder(self.hidden_size,self.num_heads,self.num_layers,
                                           relative_position=self.use_rel_pos,
                                           learnable_position=self.learnable_position,
                                           add_position=self.add_position,
                                           layer_preprocess_sequence=self.layer_preprocess_sequence,
                                           layer_postprocess_sequence=self.layer_postprocess_sequence,
                                           dropout=self.dropout,
                                           scaled=self.scaled,
                                           ff_size=self.ff_size,
                                           mode=self.mode,
                                           dvc=self.dvc,
                                           max_seq_len=self.max_seq_len,
                                           pe=self.pe,
                                           pe_ss=self.pe_ss,
                                           pe_se=self.pe_se,
                                           pe_es=self.pe_es,
                                           pe_ee=self.pe_ee,
                                           k_proj=self.k_proj,
                                           q_proj=self.q_proj,
                                           v_proj=self.v_proj,
                                           r_proj=self.r_proj,
                                           attn_ff=self.attn_ff,
                                           ff_activate=self.ff_activate,
                                           lattice=True,
                                           four_pos_fusion=self.four_pos_fusion,
                                           four_pos_fusion_shared=self.four_pos_fusion_shared)

        self.output_dropout = MyDropout(self.dropout['output'])

        self.output = nn.Linear(self.hidden_size,self.label_size)
        if self.self_supervised:
            self.output_self_supervised = nn.Linear(self.hidden_size,len(vocabs['char']))
            print('self.output_self_supervised:{}'.format(self.output_self_supervised.weight.size()))
        self.crf = get_crf_zero_init(self.label_size)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
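
        # Note: the CRF above presumably provides the main sequence-labeling loss,
        # while this CrossEntropyLoss (ignore_index=-100) presumably serves the
        # auxiliary self-supervised character-prediction head created when
        # self_supervised=True (assumption; the forward/loss computation is not
        # shown in this snippet).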