def __init__(self, d_model, heads, d_ff, dropout, topic=False, topic_dim=300, split_noise=False):
    super(TransformerDecoderLayer, self).__init__()

    self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
    self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout,
                                             topic=topic, topic_dim=topic_dim,
                                             split_noise=split_noise)
    self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
    self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
    self.drop = nn.Dropout(dropout)
    mask = self._get_attn_subsequent_mask(MAX_SIZE)
    # Register self.mask as a buffer in TransformerDecoderLayer, so
    # it gets TransformerDecoderLayer's cuda behavior automatically.
    self.register_buffer('mask', mask)
def __init__(self, d_model, heads, d_ff, dropout):
    super(TransformerDecoderLayer, self).__init__()

    self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
    self.enc_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
    self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    self.layer_norm1 = nn.LayerNorm(d_model, eps=1e-6)
    self.layer_norm2 = nn.LayerNorm(d_model, eps=1e-6)
    self.dropout = nn.Dropout(dropout)
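# The two constructors above only build sub-modules; the layer's forward pass is not
# included in this section. Below is a minimal sketch of how such a decoder layer is
# usually wired (pre-norm self-attention, then attention over the encoder states, then
# the position-wise feed-forward block). It is an illustration only: the function name
# and argument names (inputs, memory_bank, masks) are assumptions, not this repository's API.
def _decoder_layer_forward_sketch(layer, inputs, memory_bank, src_pad_mask, tgt_pad_mask):
    # Masked self-attention over the normalized decoder inputs.
    input_norm = layer.layer_norm1(inputs)
    query = layer.self_attn(input_norm, input_norm, input_norm,
                            mask=tgt_pad_mask, type="self")
    query = layer.dropout(query) + inputs
    # Attention over the encoder memory bank, conditioned on the self-attention output.
    query_norm = layer.layer_norm2(query)
    mid = layer.enc_attn(memory_bank, memory_bank, query_norm,
                         mask=src_pad_mask, type="context")
    # Position-wise feed-forward (residual connection handled inside feed_forward).
    return layer.feed_forward(layer.dropout(mid) + query)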
def __init__(self, args, device, checkpoint=None, checkpoint_ext=None, checkpoint_abs=None):
    super(HybridSummarizer, self).__init__()
    self.args = args
    self.device = device
    self.extractor = ExtSummarizer(args, device, checkpoint_ext)
    self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
    self.context_attn = MultiHeadedAttention(
        head_count=self.args.dec_heads, model_dim=self.args.dec_hidden_size,
        dropout=self.args.dec_dropout, need_distribution=True)
    self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
    self.bv = nn.Parameter(torch.Tensor(1))
    self.attn_lin = nn.Linear(self.args.dec_hidden_size, self.args.dec_hidden_size)
    if self.args.hybrid_loss:
        self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
    if self.args.hybrid_connector:
        self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

    # At test time, load the trained checkpoint directly.
    if checkpoint is not None:
        self.load_state_dict(checkpoint['model'], strict=True)
        print("checkpoint loaded!")
    else:
        self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)
        nn.init.xavier_uniform_(self.v)
        nn.init.constant_(self.bv, 0)
        if self.args.hybrid_connector:
            for module in self.p_sen.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        for module in self.context_attn.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
    self.to(device)
def __init__(self, d_model, heads, d_ff, dropout):
    super(TransformerDecoderLayer, self).__init__()

    self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
    self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
    self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
    self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
    self.drop = nn.Dropout(dropout)
    mask = self._get_attn_subsequent_mask(5000)
    self.register_buffer('mask', mask)
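# `_get_attn_subsequent_mask` is referenced above but not defined in this section. The
# sketch below shows the usual causal (upper-triangular) mask builder, assuming the
# common OpenNMT-style implementation; the actual helper may differ slightly.
import numpy as np
import torch

def _get_attn_subsequent_mask_sketch(size):
    # Entries with j > i are 1, so the decoder cannot attend to future positions.
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask)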
def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
    super(TransformerDecoder, self).__init__()

    # Basic attributes.
    self.decoder_type = 'transformer'
    self.num_layers = num_layers
    self.embeddings = embeddings
    self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)

    self.context_attn_graph = MultiHeadedAttention(heads, d_model, dropout=dropout)
    self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    self.drop_3 = nn.Dropout(dropout)
    self.layer_norm_3 = nn.LayerNorm(d_model, eps=1e-6)

    # Build TransformerDecoder.
    self.transformer_layers = nn.ModuleList(
        [TransformerDecoderLayer(d_model, heads, d_ff, dropout)
         for _ in range(num_layers)])
    self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    # Attention weights (context / query / context-query) used for graph-aware attention.
    self.att_weight_c = nn.Linear(self.embeddings.embedding_dim, 1)
    self.att_weight_q = nn.Linear(self.embeddings.embedding_dim, 1)
    self.att_weight_cq = nn.Linear(self.embeddings.embedding_dim, 1)
    self.graph_act = gelu
    self.graph_aware = nn.Linear(self.embeddings.embedding_dim * 3, self.embeddings.embedding_dim)
    self.graph_drop = nn.Dropout(dropout)
    self.linear_filter = nn.Linear(d_model * 2, 1)

    # Learnable position prior of shape (8, 512, 512); each row decays linearly
    # from 1.0 down to 1/512.
    fix_top = (torch.arange(512, 0, -1).float() / 512) \
        .unsqueeze(0).unsqueeze(0).expand(8, 512, -1).contiguous().to(self.get_device())
    self.fix_top = torch.nn.Parameter(fix_top, requires_grad=True)
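# The three att_weight_* projections above appear to follow a BiDAF-style trilinear
# similarity between decoder states `c` and graph/query states `q`; the forward pass is
# not shown in this section. The sketch below illustrates the usual way such weights are
# combined. Tensor names and the helper itself are assumptions for illustration only.
import torch

def _trilinear_similarity_sketch(decoder, c, q):
    # c: (batch, c_len, dim), q: (batch, q_len, dim)
    c_len, q_len = c.size(1), q.size(1)
    s_c = decoder.att_weight_c(c).expand(-1, -1, q_len)                   # w_c . c_i
    s_q = decoder.att_weight_q(q).permute(0, 2, 1).expand(-1, c_len, -1)  # w_q . q_j
    # w_cq . (c_i * q_j), computed column by column to avoid a (c_len, q_len, dim) tensor.
    s_cq = torch.stack(
        [decoder.att_weight_cq(c * q[:, j, :].unsqueeze(1)).squeeze(-1)
         for j in range(q_len)], dim=-1)
    return s_c + s_q + s_cq  # similarity matrix of shape (batch, c_len, q_len)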
def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
    super(Z_TransformerDecoder, self).__init__()

    # Basic attributes.
    self.decoder_type = 'transformer'
    self.num_layers = num_layers
    self.embeddings = embeddings
    self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)
    self.vocab_size = vocab_size
    if COPY:
        self.copy_attn = MultiHeadedAttention(1, d_model, dropout=dropout)

    # Build TransformerDecoder.
    self.transformer_layers = nn.ModuleList(
        [Z_TransformerDecoderLayer(d_model, heads, d_ff, dropout)
         for _ in range(num_layers)])
    self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
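# Neither decoder stack above includes its forward pass in this section. The sketch below
# shows the usual flow (embed targets, add positional encoding, run each layer against the
# encoder memory bank, final LayerNorm), assuming each layer returns the updated decoder
# states as in the layer sketch earlier; argument names are illustrative assumptions.
def _decoder_stack_forward_sketch(decoder, tgt, memory_bank, src_pad_mask, tgt_pad_mask):
    emb = decoder.embeddings(tgt)
    output = decoder.pos_emb(emb)  # scaled embeddings plus sinusoidal positions
    for layer in decoder.transformer_layers:
        output = layer(output, memory_bank, src_pad_mask, tgt_pad_mask)
    return decoder.layer_norm(output)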
class HybridSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, checkpoint_ext=None, checkpoint_abs=None):
        super(HybridSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.extractor = ExtSummarizer(args, device, checkpoint_ext)
        # self.abstractor = PGTransformers(modules, consts, options)
        self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
        self.context_attn = MultiHeadedAttention(
            head_count=self.args.dec_heads, model_dim=self.args.dec_hidden_size,
            dropout=self.args.dec_dropout, need_distribution=True)
        self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
        self.bv = nn.Parameter(torch.Tensor(1))
        self.attn_lin = nn.Linear(self.args.dec_hidden_size, self.args.dec_hidden_size)
        if self.args.hybrid_loss:
            self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
        if self.args.hybrid_connector:
            self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded!")
        else:
            self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)
            nn.init.xavier_uniform_(self.v)
            nn.init.constant_(self.bv, 0)
            if self.args.hybrid_connector:
                for module in self.p_sen.modules():
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module, nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()
            for module in self.context_attn.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, labels=None):
        if labels is not None and self.args.oracle:
            ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
        else:
            if labels is None:
                with torch.no_grad():
                    ext_scores, _, sent_vec = self.extractor(src, segs, clss, mask_src, mask_cls)
            else:
                ext_scores, _, sent_vec = self.extractor(src, segs, clss, mask_src, mask_cls)
                ext_loss = self.ext_loss_fun(ext_scores, labels.float())
                ext_loss = ext_loss * mask_cls.float()

        # decoder_outputs: [batch_size, tgt_len - 1, hidden_size]; later projected into a
        # probability distribution over vocab_size.
        decoder_outputs, encoder_state, y_emb = self.abstractor(
            src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
        src_pad_mask = (1 - mask_src).unsqueeze(1).repeat(1, tgt.size(1) - 1, 1)
        context_vector, attn_dist = self.context_attn(encoder_state, encoder_state, decoder_outputs,
                                                      mask=src_pad_mask, type="context")

        if self.args.hybrid_connector:
            sorted_scores, sorted_scores_idx = torch.sort(ext_scores, dim=1, descending=True)
            # Number of sentences to select (top-3, capped by the number of sentences).
            select_num = min(3, mask_cls.size(1))
            # Weight each selected top-k sentence vector by its extractive score,
            # so the weighted vectors can be summed below.
            selected_sent_vec = tuple([
                (sorted_scores[i][:select_num].unsqueeze(0).transpose(0, 1)
                 * sent_vec[i, tuple(sorted_scores_idx[i][:select_num])]).unsqueeze(0)
                for i, each in enumerate(sorted_scores_idx)])
            selected_sent_vec = torch.cat(selected_sent_vec, dim=0)
            selected_sent_vec = selected_sent_vec.sum(dim=1)
            E_sel = self.p_sen(selected_sent_vec)
            ext_scores = ext_scores * E_sel

        g = torch.sigmoid(F.linear(torch.cat([decoder_outputs, y_emb, context_vector], -1), self.v, self.bv))

        xids = src.unsqueeze(0).repeat(tgt.size(1) - 1, 1, 1).transpose(0, 1)
        xids = xids * mask_tgt.unsqueeze(2)[:, :-1, :].long()
        # Mask special tokens such as [CLS].
        len0 = src.size(1)
        len0 = torch.Tensor([[len0]]).repeat(src.size(0), 1).long().to('cuda')
        clss_up = torch.cat((clss, len0), dim=1)
        sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
        for i in range(mask_cls.size(0)):
            for j in range(mask_cls.size(1)):
                if sent_len[i][j] < 0:
                    sent_len[i][j] += src.size(1)

        ext_scores_0 = ext_scores.unsqueeze(1).transpose(1, 2).repeat(1, 1, src.size(1))
        for i in range(clss.size(0)):
            tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]
            for j in range(1, clss.size(1)):
                tmp_vec = torch.cat((tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]), dim=0)
            if i == 0:
                ext_scores_new = tmp_vec.unsqueeze(0)
            else:
                ext_scores_new = torch.cat((ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
        ext_scores_new = ext_scores_new * mask_src.float()

        attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)
        # Renormalize the re-weighted attention distribution (weighted-sum formula).
        attn_dist = attn_dist / attn_dist.sum(dim=2).unsqueeze(2)
        ext_dist = Variable(torch.zeros(tgt.size(0), tgt.size(1) - 1,
                                        self.abstractor.bert.model.config.vocab_size).to(self.device))
        ext_vocab_prob = ext_dist.scatter_add(
            2, xids,
            (1 - g) * mask_tgt.unsqueeze(2)[:, :-1, :].float() * attn_dist
        ) * mask_tgt.unsqueeze(2)[:, :-1, :].float()

        if self.args.hybrid_loss:
            return decoder_outputs, None, (ext_vocab_prob, g, ext_loss)
        else:
            return decoder_outputs, None, (ext_vocab_prob, g)
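# How the returned tuple is consumed is not shown in this section. In a pointer-generator
# setup like the one above, ext_vocab_prob already carries the (1 - g) copy weight (see the
# scatter_add call), so the final word distribution is typically the generator output scaled
# by g plus the copy probabilities. A minimal sketch, assuming a generator that maps decoder
# states to per-token vocabulary probabilities (not log-probabilities):
def _mix_generate_and_copy_sketch(generator, decoder_outputs, ext_vocab_prob, g):
    vocab_prob = generator(decoder_outputs)       # (batch, tgt_len - 1, vocab_size)
    final_prob = g * vocab_prob + ext_vocab_prob  # mixture of generation and copying
    return final_prob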
class HybridSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, checkpoint_ext=None, checkpoint_abs=None):
        super(HybridSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.extractor = ExtSummarizer(args, device, checkpoint_ext)
        self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
        self.context_attn = MultiHeadedAttention(
            head_count=self.args.dec_heads, model_dim=self.args.dec_hidden_size,
            dropout=self.args.dec_dropout, need_distribution=True)
        self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
        self.bv = nn.Parameter(torch.Tensor(1))
        self.attn_lin = nn.Linear(self.args.dec_hidden_size, self.args.dec_hidden_size)
        if self.args.hybrid_loss:
            self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
        if self.args.hybrid_connector:
            self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

        # At test time, load the trained checkpoint directly.
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded!")
        else:
            self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)
            nn.init.xavier_uniform_(self.v)
            nn.init.constant_(self.bv, 0)
            if self.args.hybrid_connector:
                for module in self.p_sen.modules():
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module, nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()
            for module in self.context_attn.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, labels=None):
        if labels is not None and self.args.oracle:
            ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
        else:
            if labels is None:
                with torch.no_grad():
                    ext_scores, _, sent_vec = self.extractor(
                        src, segs, clss, mask_src, mask_cls)
            else:
                ext_scores, _, sent_vec = self.extractor(
                    src, segs, clss, mask_src, mask_cls)
                ext_loss = self.ext_loss_fun(ext_scores, labels.float())
                ext_loss = ext_loss * mask_cls.float()

        decoder_outputs, encoder_state, y_emb = self.abstractor(
            src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
        src_pad_mask = (1 - mask_src).unsqueeze(1).repeat(1, tgt.size(1) - 1, 1)
        context_vector, attn_dist = self.context_attn(encoder_state, encoder_state, decoder_outputs,
                                                      mask=src_pad_mask, type="context")

        if self.args.hybrid_connector:
            sorted_scores, sorted_scores_idx = torch.sort(ext_scores, dim=1, descending=True)
            select_num = min(3, mask_cls.size(1))
            selected_sent_vec = tuple([
                (sorted_scores[i][:select_num].unsqueeze(0).transpose(0, 1)
                 * sent_vec[i, tuple(sorted_scores_idx[i][:select_num])]).unsqueeze(0)
                for i, each in enumerate(sorted_scores_idx)])
            selected_sent_vec = torch.cat(selected_sent_vec, dim=0)
            selected_sent_vec = selected_sent_vec.sum(dim=1)
            E_sel = self.p_sen(selected_sent_vec)
            ext_scores = ext_scores * E_sel

        def _dump_nan_debug(which):
            # Debug dump used when a NaN is detected in an intermediate tensor.
            print("ops, %s!" % which)
            print("src = ", src.size())
            print(src)
            print("tgt = ", tgt.size())
            print(tgt)
            # segs: which sentence each token belongs to.
            print("segs = ", segs.size())
            print(segs)
            # clss: start position of each sentence.
            print("clss = ", clss.size())
            print(clss)
            print("mask_src = ", mask_src.size())
            print(mask_src)
            print("mask_cls = ", mask_cls.size())
            print(mask_cls)
            print("decoder_outputs ", decoder_outputs.size())
            print(decoder_outputs)
            print("y_emb ", y_emb.size())
            print(y_emb)
            print("context_vector ", context_vector.size())
            print(context_vector)

        if torch.isnan(decoder_outputs[0][0][0]):
            _dump_nan_debug("decoder_outputs")
            exit()
        if torch.isnan(y_emb[0][0][0]):
            _dump_nan_debug("y_emb")
            exit()
        if torch.isnan(context_vector[0][0][0]):
            _dump_nan_debug("context_vector")
            exit()

        g = torch.sigmoid(
            F.linear(torch.cat([decoder_outputs, y_emb, context_vector], -1), self.v, self.bv))
        if torch.isnan(g[0][0]):
            _dump_nan_debug("g")
            print("g ", g.size())
            print(g)
            exit()

        xids = src.unsqueeze(0).repeat(tgt.size(1) - 1, 1, 1).transpose(0, 1)
        xids = xids * mask_tgt.unsqueeze(2)[:, :-1, :].long()
        len0 = src.size(1)
        len0 = torch.Tensor([[len0]]).repeat(src.size(0), 1).long().to('cuda')
        clss_up = torch.cat((clss, len0), dim=1)
        sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
        for i in range(mask_cls.size(0)):
            for j in range(mask_cls.size(1)):
                if sent_len[i][j] < 0:
                    sent_len[i][j] += src.size(1)

        ext_scores_0 = ext_scores.unsqueeze(1).transpose(1, 2).repeat(1, 1, src.size(1))
        for i in range(clss.size(0)):
            tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]
            for j in range(1, clss.size(1)):
                tmp_vec = torch.cat((tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]), dim=0)
            if i == 0:
                ext_scores_new = tmp_vec.unsqueeze(0)
            else:
                ext_scores_new = torch.cat((ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
        ext_scores_new = ext_scores_new * mask_src.float()

        attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)
        attn_dist = attn_dist / attn_dist.sum(dim=2).unsqueeze(2)
        ext_dist = Variable(torch.zeros(tgt.size(0), tgt.size(1) - 1,
                                        self.abstractor.bert.model.config.vocab_size).to(self.device))
        ext_vocab_prob = ext_dist.scatter_add(
            2, xids,
            (1 - g) * mask_tgt.unsqueeze(2)[:, :-1, :].float() * attn_dist
        ) * mask_tgt.unsqueeze(2)[:, :-1, :].float()

        if self.args.hybrid_loss:
            return decoder_outputs, None, (ext_vocab_prob, g, ext_loss)
        else:
            return decoder_outputs, None, (ext_vocab_prob, g)
def __init__(self, args, device, vocab_size, product_size, vocab_words, word_dists=None):
    super(ItemTransformerRanker, self).__init__()
    self.args = args
    self.device = device
    self.train_review_only = args.train_review_only
    self.embedding_size = args.embedding_size
    self.vocab_words = vocab_words
    self.word_dists = None
    if word_dists is not None:
        self.word_dists = torch.tensor(word_dists, device=device)
    self.prod_dists = torch.ones(product_size, device=device)
    self.prod_pad_idx = product_size
    self.word_pad_idx = vocab_size - 1
    self.seg_pad_idx = 3
    self.emb_dropout = args.dropout
    self.pretrain_emb_dir = None
    if os.path.exists(args.pretrain_emb_dir):
        self.pretrain_emb_dir = args.pretrain_emb_dir
    self.pretrain_up_emb_dir = None
    if os.path.exists(args.pretrain_up_emb_dir):
        self.pretrain_up_emb_dir = args.pretrain_up_emb_dir
    self.dropout_layer = nn.Dropout(p=args.dropout)

    self.product_emb = nn.Embedding(product_size + 1, self.embedding_size,
                                    padding_idx=self.prod_pad_idx)
    if args.sep_prod_emb:
        self.hist_product_emb = nn.Embedding(product_size + 1, self.embedding_size,
                                             padding_idx=self.prod_pad_idx)
    '''
    else:
        pretrain_product_emb_path = os.path.join(self.pretrain_up_emb_dir, "product_emb.txt")
        pretrained_weights = load_user_item_embeddings(pretrain_product_emb_path)
        pretrained_weights.append([0.] * len(pretrained_weights[0]))
        self.product_emb = nn.Embedding.from_pretrained(
            torch.FloatTensor(pretrained_weights), padding_idx=self.prod_pad_idx)
    '''
    self.product_bias = nn.Parameter(torch.zeros(product_size + 1), requires_grad=True)
    self.word_bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True)

    if self.pretrain_emb_dir is not None:
        # Embeddings for query and target words in pv and pvc.
        word_emb_fname = "word_emb.txt.gz"
        pretrain_word_emb_path = os.path.join(self.pretrain_emb_dir, word_emb_fname)
        word_index_dic, pretrained_weights = load_pretrain_embeddings(pretrain_word_emb_path)
        word_indices = torch.tensor(
            [0] + [word_index_dic[x] for x in self.vocab_words[1:]] + [self.word_pad_idx])
        pretrained_weights = torch.FloatTensor(pretrained_weights)
        # Vectors at the padding index will not be updated.
        self.word_embeddings = nn.Embedding.from_pretrained(
            pretrained_weights[word_indices], padding_idx=self.word_pad_idx)
    else:
        self.word_embeddings = nn.Embedding(vocab_size, self.embedding_size,
                                            padding_idx=self.word_pad_idx)

    if self.args.model_name == "item_transformer":
        self.transformer_encoder = TransformerEncoder(
            self.embedding_size, args.ff_size, args.heads, args.dropout, args.inter_layers)
    else:
        # e.g. when model_name is "ZAM" or "AEM".
        self.attention_encoder = MultiHeadedAttention(
            args.heads, self.embedding_size, args.dropout)

    if args.query_encoder_name == "fs":
        self.query_encoder = FSEncoder(self.embedding_size, self.emb_dropout)
    else:
        self.query_encoder = AVGEncoder(self.embedding_size, self.emb_dropout)

    # Segment embeddings for each (q, u, i): the query, previous purchases of u,
    # currently available reviews of i, and a padding value (seg_pad_idx).
    self.seg_embeddings = nn.Embedding(4, self.embedding_size, padding_idx=self.seg_pad_idx)
    # self.logsoftmax = torch.nn.LogSoftmax(dim=-1)
    self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(reduction='none')  # default reduction is 'mean'
    self.initialize_parameters(logger)
    self.to(device)  # move the model in place
    self.item_loss = 0
    self.ps_loss = 0
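# The segment embedding above distinguishes the query, the user's previous purchases, and
# the candidate item's reviews, with seg_pad_idx (3) reserved for padding. A minimal sketch
# of how such segment ids could be laid out for one flattened sequence; the ordering and the
# helper are illustrative assumptions, not this repository's preprocessing code.
import torch

def _build_segment_ids_sketch(query_len, hist_len, review_len, max_len, seg_pad_idx=3):
    seg_ids = torch.full((max_len,), seg_pad_idx, dtype=torch.long)
    seg_ids[:query_len] = 0                                               # query words
    seg_ids[query_len:query_len + hist_len] = 1                           # previous purchases of u
    seg_ids[query_len + hist_len:query_len + hist_len + review_len] = 2   # reviews of item i
    return seg_ids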
def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
    super(AbsSummarizer, self).__init__()
    self.args = args
    self.device = device
    self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

    if bert_from_extractive is not None:
        self.bert.model.load_state_dict(
            dict([(n[11:], p) for n, p in bert_from_extractive.items()
                  if n.startswith('bert.model')]), strict=True)

    if args.encoder == 'baseline':
        bert_config = BertConfig(self.bert.model.config.vocab_size,
                                 hidden_size=args.enc_hidden_size,
                                 num_hidden_layers=args.enc_layers,
                                 num_attention_heads=8,
                                 intermediate_size=args.enc_ff_size,
                                 hidden_dropout_prob=args.enc_dropout,
                                 attention_probs_dropout_prob=args.enc_dropout)
        self.bert.model = BertModel(bert_config)

    if args.max_pos > 512:
        # Extend BERT's learned position embeddings beyond 512 positions by
        # repeating the last embedding for the extra positions.
        my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
        my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
        my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
        self.bert.model.embeddings.position_embeddings = my_pos_embeddings

    self.vocab_size = self.bert.model.config.vocab_size

    self.enc_out_size = self.args.dec_hidden_size
    if self.args.use_dep:
        self.enc_out_size += 2
    if self.args.use_frame:
        self.enc_frame = nn.Linear(1, 20)
        self.frame_attn = MultiHeadedAttention(1, 20, 0.1)
        self.enc_out_size += 20
    self.enc_out = nn.Linear(self.enc_out_size, self.args.dec_hidden_size)
    self.drop = nn.Dropout(self.args.enc_dropout)
    self.layer_norm = nn.LayerNorm(self.args.dec_hidden_size, eps=1e-6)

    tgt_embeddings = nn.Embedding(self.vocab_size, self.args.dec_hidden_size, padding_idx=0)
    if self.args.share_emb:
        tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

    self.decoder = TransformerDecoder(
        self.args.dec_layers,
        self.args.dec_hidden_size, heads=self.args.dec_heads,
        d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout,
        embeddings=tgt_embeddings)

    self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
    self.generator[0].weight = self.decoder.embeddings.weight

    if checkpoint is not None:
        self.load_state_dict(checkpoint['model'], strict=True)
    else:
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()
        if args.use_bert_emb:
            tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
            tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
            self.decoder.embeddings = tgt_embeddings
            self.generator[0].weight = self.decoder.embeddings.weight
    self.to(device)
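# get_generator is not defined in this section. The weight tying above
# (self.generator[0].weight = self.decoder.embeddings.weight) implies that index 0 of the
# generator is a Linear layer from dec_hidden_size to vocab_size. A sketch, assuming the
# common PreSumm-style definition (Linear followed by LogSoftmax); the real helper may differ.
import torch.nn as nn

def get_generator_sketch(vocab_size, dec_hidden_size, device):
    generator = nn.Sequential(
        nn.Linear(dec_hidden_size, vocab_size),  # index 0: projection whose weight is tied
        nn.LogSoftmax(dim=-1))
    generator.to(device)
    return generator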