def __init__(self, char_embed, bigram_embed, hidden_size, label_size,
             num_heads, num_layers,
             use_abs_pos, use_rel_pos, learnable_position, add_position,
             layer_preprocess_sequence, layer_postprocess_sequence,
             ff_size=-1, scaled=True, dropout=None, use_bigram=True,
             mode=None,
             dvc=None, vocabs=None, rel_pos_shared=True, max_seq_len=-1,
             k_proj=True, q_proj=True, v_proj=True, r_proj=True,
             self_supervised=False, attn_ff=True, pos_norm=False,
             ff_activate='relu', rel_pos_init=0,
             abs_pos_fusion_func='concat', embed_dropout_pos='0'):
    """Build a character(+bigram)-based Transformer sequence labeler.

    :param char_embed: character embedding module (exposes ``.embedding.weight``).
    :param bigram_embed: bigram embedding module, concatenated with the char
        embedding when ``use_bigram`` is True.
    :param rel_pos_init: if 0, the relative position embedding for the range
        [-max_len, max_len] is initialized as indices 0..2*max_len;
        if 1, it is initialized as -max_len..max_len.
    :param mode: dict-like of boolean feature flags; defaults to a fresh
        ``collections.defaultdict(bool)`` per instance.
    :param dropout: dict-like mapping component name ('embed', 'output', ...)
        to dropout rate; defaults to all-zero dropout.
    :param max_seq_len: required (>= 0) when ``use_rel_pos`` is True.
    """
    super().__init__()
    # BUG FIX: the original signature used mode=collections.defaultdict(bool),
    # a mutable default argument shared by every instance that relied on the
    # default. Create a fresh one per call instead (backward-compatible).
    if mode is None:
        mode = collections.defaultdict(bool)

    self.abs_pos_fusion_func = abs_pos_fusion_func
    self.embed_dropout_pos = embed_dropout_pos
    self.mode = mode
    self.char_embed = char_embed
    self.bigram_embed = bigram_embed
    self.hidden_size = hidden_size
    self.label_size = label_size
    self.num_heads = num_heads
    self.num_layers = num_layers
    self.use_abs_pos = use_abs_pos
    self.use_rel_pos = use_rel_pos
    self.learnable_position = learnable_position
    self.add_position = add_position
    self.rel_pos_shared = rel_pos_shared
    self.self_supervised = self_supervised
    self.vocabs = vocabs
    self.attn_ff = attn_ff
    self.pos_norm = pos_norm
    self.ff_activate = ff_activate
    self.rel_pos_init = rel_pos_init

    # Relative position encoding needs a concrete maximum sequence length.
    if self.use_rel_pos and max_seq_len < 0:
        print_info('max_seq_len should be set if relative position encode')
        exit(1208)
    self.max_seq_len = max_seq_len

    self.k_proj = k_proj
    self.q_proj = q_proj
    self.v_proj = v_proj
    self.r_proj = r_proj

    self.pe = None
    if self.use_abs_pos:
        self.pos_encode = Absolute_Position_Embedding(
            self.abs_pos_fusion_func,
            self.hidden_size,
            learnable=self.learnable_position,
            mode=self.mode,
            pos_norm=self.pos_norm)

    if self.use_rel_pos:
        pe = get_embedding(max_seq_len, hidden_size,
                           rel_pos_init=self.rel_pos_init)
        pe_sum = pe.sum(dim=-1, keepdim=True)
        if self.pos_norm:
            # Normalize each position's encoding by its component sum;
            # done under no_grad so the normalization is not tracked.
            with torch.no_grad():
                pe = pe / pe_sum
        self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
    else:
        self.pe = None

    self.layer_preprocess_sequence = layer_preprocess_sequence
    self.layer_postprocess_sequence = layer_postprocess_sequence
    if ff_size == -1:
        # Default the feed-forward size to the model width.
        ff_size = self.hidden_size
    self.ff_size = ff_size
    self.scaled = scaled
    if dvc is None:
        dvc = 'cpu'
    self.dvc = torch.device(dvc)
    if dropout is None:
        # No dropout rates given: default every component to 0.
        self.dropout = collections.defaultdict(int)
    else:
        self.dropout = dropout
    self.use_bigram = use_bigram

    # Input width: char embedding, optionally concatenated with bigram.
    if self.use_bigram:
        self.input_size = (self.char_embed.embedding.weight.size(1)
                           + self.bigram_embed.embedding.weight.size(1))
    else:
        self.input_size = self.char_embed.embedding.weight.size(1)

    self.embed_dropout = MyDropout(self.dropout['embed'])
    self.w_proj = nn.Linear(self.input_size, self.hidden_size)
    self.encoder = Transformer_Encoder(
        self.hidden_size,
        self.num_heads,
        self.num_layers,
        relative_position=self.use_rel_pos,
        learnable_position=self.learnable_position,
        add_position=self.add_position,
        layer_preprocess_sequence=self.layer_preprocess_sequence,
        layer_postprocess_sequence=self.layer_postprocess_sequence,
        dropout=self.dropout,
        scaled=self.scaled,
        ff_size=self.ff_size,
        mode=self.mode,
        dvc=self.dvc,
        max_seq_len=self.max_seq_len,
        pe=self.pe,
        k_proj=self.k_proj,
        q_proj=self.q_proj,
        v_proj=self.v_proj,
        r_proj=self.r_proj,
        attn_ff=self.attn_ff,
        ff_activate=self.ff_activate,
    )
    self.output_dropout = MyDropout(self.dropout['output'])
    self.output = nn.Linear(self.hidden_size, self.label_size)
    if self.self_supervised:
        # Auxiliary head predicting characters for self-supervision.
        self.output_self_supervised = nn.Linear(self.hidden_size,
                                                len(vocabs['char']))
        print('self.output_self_supervised:{}'.format(
            self.output_self_supervised.weight.size()))
    self.crf = get_crf_zero_init(self.label_size)
    self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
def __init__(self, lattice_embed, bigram_embed, hidden_size, label_size,
             num_heads, num_layers,
             use_abs_pos, use_rel_pos, learnable_position, add_position,
             layer_preprocess_sequence, layer_postprocess_sequence,
             ff_size=-1, scaled=True, dropout=None, use_bigram=True,
             mode=None,
             dvc=None, vocabs=None, rel_pos_shared=True, max_seq_len=-1,
             k_proj=True, q_proj=True, v_proj=True, r_proj=True,
             self_supervised=False, attn_ff=True, pos_norm=False,
             ff_activate='relu', rel_pos_init=0,
             abs_pos_fusion_func='concat', embed_dropout_pos='0',
             four_pos_shared=True, four_pos_fusion=None,
             four_pos_fusion_shared=True, bert_embedding=None):
    """Build a lattice Transformer sequence labeler (FLAT-style).

    :param lattice_embed: embedding over the lattice vocabulary (chars + words).
    :param rel_pos_init: if 0, the relative position embedding for the range
        [-max_len, max_len] is initialized as indices 0..2*max_len;
        if 1, it is initialized as -max_len..max_len.
    :param embed_dropout_pos: '0' = dropout right after embedding;
        '1' = dropout after projecting to hidden size;
        '2' = dropout after the absolute position is added.
    :param four_pos_shared: if True, the four span-relative position tables
        (ss/se/es/ee) share one parameter; otherwise each gets its own copy.
    :param four_pos_fusion: required when ``use_rel_pos`` is True; how the
        four relative-position encodings are fused.
    :param bert_embedding: optional BERT embedder; when given, BERT features
        are concatenated to the character input.
    """
    super().__init__()
    # BUG FIX: the original signature used mode=collections.defaultdict(bool),
    # a mutable default argument shared by every instance that relied on the
    # default. Create a fresh one per call instead (backward-compatible).
    if mode is None:
        mode = collections.defaultdict(bool)

    # Whether BERT features are used is inferred from bert_embedding.
    self.use_bert = False
    if bert_embedding is not None:
        self.use_bert = True
        self.bert_embedding = bert_embedding

    self.four_pos_fusion_shared = four_pos_fusion_shared
    self.mode = mode
    self.four_pos_shared = four_pos_shared
    self.abs_pos_fusion_func = abs_pos_fusion_func
    self.lattice_embed = lattice_embed
    self.bigram_embed = bigram_embed
    self.hidden_size = hidden_size
    self.label_size = label_size
    self.num_heads = num_heads
    self.num_layers = num_layers
    self.use_abs_pos = use_abs_pos
    self.use_rel_pos = use_rel_pos
    if self.use_rel_pos:
        # A fusion strategy for the four relative encodings is mandatory.
        assert four_pos_fusion is not None
    self.four_pos_fusion = four_pos_fusion
    self.learnable_position = learnable_position
    self.add_position = add_position
    self.rel_pos_shared = rel_pos_shared
    self.self_supervised = self_supervised
    self.vocabs = vocabs
    self.attn_ff = attn_ff
    self.pos_norm = pos_norm
    self.ff_activate = ff_activate
    self.rel_pos_init = rel_pos_init
    self.embed_dropout_pos = embed_dropout_pos

    # Relative position encoding needs a concrete maximum sequence length.
    if self.use_rel_pos and max_seq_len < 0:
        print_info('max_seq_len should be set if relative position encode')
        exit(1208)
    self.max_seq_len = max_seq_len

    self.k_proj = k_proj
    self.q_proj = q_proj
    self.v_proj = v_proj
    self.r_proj = r_proj

    self.pe = None
    if self.use_abs_pos:
        self.abs_pos_encode = Absolute_SE_Position_Embedding(
            self.abs_pos_fusion_func,
            self.hidden_size,
            learnable=self.learnable_position,
            mode=self.mode,
            pos_norm=self.pos_norm)

    if self.use_rel_pos:
        pe = get_embedding(max_seq_len, hidden_size,
                           rel_pos_init=self.rel_pos_init)
        pe_sum = pe.sum(dim=-1, keepdim=True)
        if self.pos_norm:
            # Normalize each position's encoding by its component sum;
            # done under no_grad so the normalization is not tracked.
            with torch.no_grad():
                pe = pe / pe_sum
        self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
        if self.four_pos_shared:
            # One shared table for all four span-relative distances.
            self.pe_ss = self.pe
            self.pe_se = self.pe
            self.pe_es = self.pe
            self.pe_ee = self.pe
        else:
            # Independent (deep-copied) tables per distance type.
            self.pe_ss = nn.Parameter(copy.deepcopy(pe),
                                      requires_grad=self.learnable_position)
            self.pe_se = nn.Parameter(copy.deepcopy(pe),
                                      requires_grad=self.learnable_position)
            self.pe_es = nn.Parameter(copy.deepcopy(pe),
                                      requires_grad=self.learnable_position)
            self.pe_ee = nn.Parameter(copy.deepcopy(pe),
                                      requires_grad=self.learnable_position)
    else:
        self.pe = None
        self.pe_ss = None
        self.pe_se = None
        self.pe_es = None
        self.pe_ee = None

    self.layer_preprocess_sequence = layer_preprocess_sequence
    self.layer_postprocess_sequence = layer_postprocess_sequence
    if ff_size == -1:
        # Default the feed-forward size to the model width.
        ff_size = self.hidden_size
    self.ff_size = ff_size
    self.scaled = scaled
    if dvc is None:
        dvc = 'cpu'
    self.dvc = torch.device(dvc)
    if dropout is None:
        # No dropout rates given: default every component to 0.
        self.dropout = collections.defaultdict(int)
    else:
        self.dropout = dropout
    self.use_bigram = use_bigram

    # Character-token input width: lattice embedding, optionally with bigram.
    if self.use_bigram:
        self.bigram_size = self.bigram_embed.embedding.weight.size(1)
        self.char_input_size = (self.lattice_embed.embedding.weight.size(1)
                                + self.bigram_embed.embedding.weight.size(1))
    else:
        self.char_input_size = self.lattice_embed.embedding.weight.size(1)

    if self.use_bert:
        self.char_input_size += self.bert_embedding._embed_size

    # Lexicon (word) tokens use the lattice embedding width only.
    self.lex_input_size = self.lattice_embed.embedding.weight.size(1)

    self.embed_dropout = MyDropout(self.dropout['embed'])
    self.gaz_dropout = MyDropout(self.dropout['gaz'])

    self.char_proj = nn.Linear(self.char_input_size, self.hidden_size)
    self.lex_proj = nn.Linear(self.lex_input_size, self.hidden_size)

    self.encoder = Transformer_Encoder(
        self.hidden_size,
        self.num_heads,
        self.num_layers,
        relative_position=self.use_rel_pos,
        learnable_position=self.learnable_position,
        add_position=self.add_position,
        layer_preprocess_sequence=self.layer_preprocess_sequence,
        layer_postprocess_sequence=self.layer_postprocess_sequence,
        dropout=self.dropout,
        scaled=self.scaled,
        ff_size=self.ff_size,
        mode=self.mode,
        dvc=self.dvc,
        max_seq_len=self.max_seq_len,
        pe=self.pe,
        pe_ss=self.pe_ss,
        pe_se=self.pe_se,
        pe_es=self.pe_es,
        pe_ee=self.pe_ee,
        k_proj=self.k_proj,
        q_proj=self.q_proj,
        v_proj=self.v_proj,
        r_proj=self.r_proj,
        attn_ff=self.attn_ff,
        ff_activate=self.ff_activate,
        lattice=True,
        four_pos_fusion=self.four_pos_fusion,
        four_pos_fusion_shared=self.four_pos_fusion_shared)

    self.output_dropout = MyDropout(self.dropout['output'])
    self.output = nn.Linear(self.hidden_size, self.label_size)
    if self.self_supervised:
        # Auxiliary head predicting characters for self-supervision.
        self.output_self_supervised = nn.Linear(self.hidden_size,
                                                len(vocabs['char']))
        print('self.output_self_supervised:{}'.format(
            self.output_self_supervised.weight.size()))
    self.crf = get_crf_zero_init(self.label_size)
    self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)