def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
    super(AbsSummarizer, self).__init__()
    self.args = args
    self.device = device
    self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

    if bert_from_extractive is not None:
        self.bert.model.load_state_dict(
            dict([(n[11:], p) for n, p in bert_from_extractive.items()
                  if n.startswith('bert.model')]), strict=True)

    if args.encoder == 'baseline':
        bert_config = BertConfig(self.bert.model.config.vocab_size,
                                 hidden_size=args.enc_hidden_size,
                                 num_hidden_layers=args.enc_layers,
                                 num_attention_heads=8,
                                 intermediate_size=args.enc_ff_size,
                                 hidden_dropout_prob=args.enc_dropout,
                                 attention_probs_dropout_prob=args.enc_dropout)
        self.bert.model = BertModel(bert_config)

    # BERT ships with 512 learned positions; for longer inputs, copy the
    # original table and repeat the last position embedding for the rest.
    if args.max_pos > 512:
        my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
        my_pos_embeddings.weight.data[:512] = \
            self.bert.model.embeddings.position_embeddings.weight.data
        my_pos_embeddings.weight.data[512:] = \
            self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
        self.bert.model.embeddings.position_embeddings = my_pos_embeddings

    self.vocab_size = self.bert.model.config.vocab_size
    tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size,
                                  padding_idx=0)
    if self.args.share_emb:
        tgt_embeddings.weight = copy.deepcopy(
            self.bert.model.embeddings.word_embeddings.weight)

    self.decoder = TransformerDecoder(
        self.args.dec_layers, self.args.dec_hidden_size,
        heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
        dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

    self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
    self.generator[0].weight = self.decoder.embeddings.weight

    if checkpoint is not None:
        self.load_state_dict(checkpoint['model'], strict=True)
    else:
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()
        if args.use_bert_emb:
            tgt_embeddings = nn.Embedding(self.vocab_size,
                                          self.bert.model.config.hidden_size,
                                          padding_idx=0)
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)
            self.decoder.embeddings = tgt_embeddings
            self.generator[0].weight = self.decoder.embeddings.weight

    self.to(device)
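# The constructor above ties self.generator[0].weight to the decoder
# embedding matrix, which implies get_generator returns a container whose
# first element is an nn.Linear. A minimal sketch of such a head, assuming
# the usual PreSumm-style linear projection followed by log-softmax (a
# sketch, not necessarily this repo's exact implementation):
import torch.nn as nn

def get_generator(vocab_size, dec_hidden_size, device):
    generator = nn.Sequential(
        nn.Linear(dec_hidden_size, vocab_size),  # index [0]: weight can be tied to embeddings
        nn.LogSoftmax(dim=-1))
    generator.to(device)
    return generator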
class BertSummarizer(nn.Module):
    def __init__(self, checkpoint, device, temp_dir='/temp'):
        super(BertSummarizer, self).__init__()
        self.device = device
        self.bert = Bert(False, temp_dir, True)
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        self.decoder = TransformerDecoder(6, 768, heads=8, d_ff=2048,
                                          dropout=0.2, embeddings=tgt_embeddings)
        self.generator = get_generator(self.vocab_size, 768, self.device)
        self.generator[0].weight = self.decoder.embeddings.weight
        self.load_state_dict(checkpoint['model'], strict=True)
        self.to(self.device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        top_vec = self.bert(src, segs, mask_src)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
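# The forward above applies teacher forcing: the decoder consumes
# tgt[:, :-1] and is trained to predict tgt[:, 1:]. A self-contained
# sketch of that shift in the loss (hypothetical helper, not repo code):
import torch.nn.functional as F

def shifted_nll_loss(log_probs, tgt, pad_idx=0):
    # log_probs: (batch, tgt_len - 1, vocab), e.g. generator(decoder_outputs)
    # tgt:       (batch, tgt_len) gold ids, including the start token
    gold = tgt[:, 1:].contiguous().view(-1)  # each position predicts the next token
    return F.nll_loss(log_probs.view(-1, log_probs.size(-1)), gold,
                      ignore_index=pad_idx)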
def add_dec_adapters(dec_model: TransformerDecoder,
                     config: AdapterConfig) -> TransformerDecoder:
    # Replace each decoder layer with an adapter-wrapped layer
    for i in range(len(dec_model.transformer_layers)):
        dec_model.transformer_layers[i] = adapt_transformer_output(config)(
            dec_model.transformer_layers[i])
    # Freeze all parameters
    for param in dec_model.parameters():
        param.requires_grad = False
    # Unfreeze the trainable parts: layer norms and adapters
    for name, sub_module in dec_model.named_modules():
        if isinstance(sub_module, (Adapter_func, nn.LayerNorm)):
            for param_name, param in sub_module.named_parameters():
                param.requires_grad = True
    return dec_model
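# Adapter_func is defined elsewhere in this repo. A typical bottleneck
# adapter (Houlsby et al., 2019) down-projects, applies a nonlinearity,
# up-projects, and adds a residual connection. A hedged sketch of that
# shape (an assumption about Adapter_func, not its actual implementation):
import torch.nn as nn

class BottleneckAdapterSketch(nn.Module):
    def __init__(self, hidden_size, adapter_size):
        super().__init__()
        self.down = nn.Linear(hidden_size, adapter_size)
        self.act = nn.GELU()
        self.up = nn.Linear(adapter_size, hidden_size)

    def forward(self, x):
        # The residual keeps the frozen backbone's function intact at init
        return x + self.up(self.act(self.down(x)))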
def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers,
             use_doc=False, aggr='last'):
    super(TransformerDecoderSeq, self).__init__()
    self.aggr = aggr
    self.decoder = TransformerDecoder(d_model, d_ff, heads, dropout,
                                      num_inter_layers)
    self.dropout = nn.Dropout(dropout)
    self.use_doc = use_doc
    if self.use_doc:
        self.linear_doc1 = nn.Linear(2 * d_model, d_model)
        self.linear_doc2 = nn.Linear(d_model, 1)
    self.bilinear = nn.Bilinear(d_model, d_model, 1)
    self.linear_sent1 = nn.Linear(2 * d_model, d_model)
    self.linear_sent2 = nn.Linear(d_model, 1)
    self.linear = nn.Linear(2, 1)
    self.start_emb = torch.nn.Parameter(torch.rand(1, d_model))
def __init__(self, max_length, enc_vocab, dec_vocab, enc_emb_size,
             dec_emb_size, enc_units, dec_units, dropout_rate=0.1):
    super(Transformer, self).__init__()
    enc_vocab_size = len(enc_vocab.itos)
    dec_vocab_size = len(dec_vocab.itos)
    self.encoder_embedding = nn.Sequential(
        TransformerEmbedding(vocab_size=enc_vocab_size,
                             padding_idx=enc_vocab.stoi["<pad>"],
                             max_length=max_length,
                             embedding_size=enc_emb_size),
        nn.Dropout(p=dropout_rate))
    self.decoder_embedding = nn.Sequential(
        TransformerEmbedding(vocab_size=dec_vocab_size,
                             # use the decoder vocab's pad index; the original
                             # passed enc_vocab.stoi["<pad>"], which is only
                             # correct when both vocabs share the same index
                             padding_idx=dec_vocab.stoi["<pad>"],
                             max_length=max_length,
                             embedding_size=dec_emb_size),
        nn.Dropout(p=dropout_rate))
    self.encoder = nn.Sequential(
        TransformerEncoder(enc_emb_size, enc_units),
        nn.Dropout(p=dropout_rate))
    self.decoder = TransformerDecoder(dec_emb_size, enc_emb_size, dec_units)
    self.decoder_drop = nn.Dropout(p=dropout_rate)
    self.output_layer = nn.Linear(in_features=enc_units[-1],
                                  out_features=dec_vocab_size)
    self.softmax = nn.Softmax(dim=-1)
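# A hedged usage sketch: the constructor expects torchtext-style vocab
# objects exposing itos (index -> token) and stoi (token -> index).
# ToyVocab below is a hypothetical stand-in, not a class from this repo:
class ToyVocab:
    def __init__(self, tokens):
        self.itos = ["<pad>", "<s>", "</s>"] + list(tokens)
        self.stoi = {t: i for i, t in enumerate(self.itos)}

# model = Transformer(max_length=64, enc_vocab=ToyVocab(["a", "b"]),
#                     dec_vocab=ToyVocab(["x", "y"]), enc_emb_size=128,
#                     dec_emb_size=128, enc_units=[256], dec_units=[256])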
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items()
                      if n.startswith('bert.model')]), strict=True)

        if args.encoder == 'baseline':
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers,
                                     num_attention_heads=8,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos,
                                             self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers, self.args.dec_hidden_size,
            heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout, embeddings=tgt_embeddings,
            use_universal_transformer=args.dec_universal_trans)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(self.vocab_size,
                                              self.bert.model.config.hidden_size,
                                              padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        # src, tgt, and mask_src go straight into the pretrained BERT model,
        # so there is little room to inject extra linguistic features here;
        # the decoder, trained from scratch, is the natural place for changes.
        top_vec = self.bert(src, segs, mask_src)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
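# The n[11:] slice above strips the "bert.model." prefix (11 characters)
# from extractive-checkpoint keys so they match BertModel's own parameter
# names. A self-contained demonstration on a toy state dict:
from collections import OrderedDict

ext_ckpt = OrderedDict([("bert.model.embeddings.word_embeddings.weight", 0),
                        ("ext_layer.wo.weight", 1)])
stripped = dict([(n[11:], p) for n, p in ext_ckpt.items()
                 if n.startswith("bert.model")])
assert list(stripped) == ["embeddings.word_embeddings.weight"]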
def __init__(self, args, device, checkpoint=None, from_extractive=None, symbols=None):
    super(AbsSummarizer, self).__init__()
    self.args = args
    self.device = device
    self.symbols = symbols

    # TODO: build the encoder according to whether args.encoder is bert or xlnet
    self.pre_model = pre_models(args.encoder)  # pick the Bert or XLNet class
    self.encoder = self.pre_model(args, args.large, args.temp_dir,
                                  args.finetune_encoder, self.symbols)  # encoder is bert or xlnet
    # self.decoder = XLNet(args.large, args.temp_dir, args.finetune_encoder)  # decoder is xlnet

    if args.max_pos > 512:
        if args.encoder == 'bert':
            my_pos_embeddings = nn.Embedding(args.max_pos,
                                             self.encoder.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.encoder.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.encoder.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.encoder.model.embeddings.position_embeddings = my_pos_embeddings

    if from_extractive is not None:
        # NOTE: n[11:] assumes the 'bert.model.' prefix length (11 chars);
        # an 'xlnet.model.' prefix would be 12 chars.
        self.encoder.model.load_state_dict(
            dict([(n[11:], p) for n, p in from_extractive.items()
                  if n.startswith(args.encoder + '.model')]), strict=True)

    self.vocab_size = self.encoder.model.config.vocab_size
    tgt_embeddings = nn.Embedding(self.vocab_size,
                                  self.encoder.model.config.hidden_size,
                                  padding_idx=0)
    if self.args.share_emb:
        tgt_embeddings.weight = copy.deepcopy(self.encoder.model.word_embedding.weight)

    # TODO: create decoder; options: TransformerDecoder, XLNet, GPT-2
    self.decoder = TransformerDecoder(
        self.args.dec_layers, self.args.dec_hidden_size,
        heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
        dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

    # TODO: create generator; options: GPT-2, XLNet
    self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
    self.generator[0].weight = self.decoder.embeddings.weight

    if checkpoint is not None:
        self.load_state_dict(checkpoint['model'], strict=True)
    else:
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()
        if args.use_pre_emb:
            if args.encoder == 'bert':
                tgt_embeddings = nn.Embedding(self.vocab_size,
                                              self.encoder.model.config.hidden_size,
                                              padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.encoder.model.embeddings.word_embeddings.weight)
            if args.encoder == 'xlnet':
                tgt_embeddings = nn.Embedding(self.vocab_size,
                                              self.encoder.model.config.d_model,
                                              padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.encoder.model.word_embedding.weight)
            self.decoder.embeddings = tgt_embeddings
            self.generator[0].weight = self.decoder.embeddings.weight

    self.to(device)
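# pre_models is defined elsewhere in this repo; given its use above, it
# plausibly maps an encoder name to a wrapper class. A hypothetical sketch
# (assumed, not the repo's actual code; Bert and XLNet wrappers are assumed
# to be importable in this module):
def pre_models(encoder_name):
    registry = {'bert': Bert, 'xlnet': XLNet}
    return registry[encoder_name]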
class MTLAbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(MTLAbsSummarizer, self).__init__()
        self.args = args
        self.device = device

        # Initialize BERT
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Load the checkpoint from an extractive model
        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items()
                      if n.startswith('bert.model')]), strict=True)

        # Baseline (non-pretrained) BERT
        if args.encoder == 'baseline':
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers,
                                     num_attention_heads=8,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        # The positional embedding table has 512 entries in the original BERT;
        # repeat the last entry for positions beyond 512.
        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos,
                                             self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        # Initialize the Transformer decoder
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        # Initialize the generator
        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        # Insert adapter modules
        if args.enc_adapter:
            enc_hidden_size = self.bert.model.embeddings.word_embeddings.weight.shape[1]
            config = AdapterConfig(
                hidden_size=enc_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.bert.model = add_enc_adapters(self.bert.model, config)
            self.bert.model = add_layer_norm(self.bert.model,
                                             d_model=enc_hidden_size,
                                             eps=args.layer_norm_eps)
        if args.dec_adapter:
            config = AdapterConfig(
                hidden_size=args.dec_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.decoder = add_dec_adapters(self.decoder, config)
            self.decoder = add_layer_norm(self.decoder,
                                          d_model=args.dec_hidden_size,
                                          eps=args.layer_norm_eps)
            self.generator[0].weight.requires_grad = False
            self.generator[0].bias.requires_grad = False

        # Load the checkpoint
        def modify_ckpt_for_enc_adapter(ckpt):
            """Modifies a no-adapter ckpt for an adapter-equipped encoder."""
            keys_need_modified_enc = []
            for k in list(ckpt['model'].keys()):
                if 'output' in k:
                    keys_need_modified_enc.append(k)
            for mk in keys_need_modified_enc:
                ckpt['model'] = OrderedDict([
                    (mk.replace('output', 'output.self_output'), v)
                    if k == mk else (k, v)
                    for k, v in ckpt['model'].items()
                ])

        def modify_ckpt_for_dec_adapter(ckpt):
            """Modifies a no-adapter ckpt for an adapter-equipped decoder."""
            keys_need_modified_dec = []
            for k in list(ckpt['model'].keys()):
                if 'layers' in k:
                    keys_need_modified_dec.append(k)
            for mk in keys_need_modified_dec:
                p = mk.find('layers.')
                new_k = mk[:p + 8] + '.dec_layer' + mk[p + 8:]
                ckpt['model'] = OrderedDict([
                    (new_k, v) if k == mk else (k, v)
                    for k, v in ckpt['model'].items()
                ])

        def identify_unmatched_keys(ckpt1, ckpt2):
            """Reports the keys of ckpt1 left uncovered when loading ckpt2
            into ckpt1. (debug use)"""
            fp = open("unmatched_keys.txt", 'w')
            num = 0
            ckpt1_keys = list(ckpt1.keys())
            ckpt2_keys = list(ckpt2.keys())
            for k in ckpt1_keys:
                if k not in ckpt2_keys and "var" not in k and "feed_forward" not in k:
                    # NOTE: "var" and "feed_forward" use shared weights from other modules
                    fp.write(k + '\n')
                    print(k)
                    num += 1
            print("# of Unmatched Keys: {}".format(num))
            fp.close()

        if checkpoint is not None:
            if self.args.enc_adapter and self.args.ckpt_from_no_adapter:
                modify_ckpt_for_enc_adapter(checkpoint)
            if self.args.dec_adapter and self.args.ckpt_from_no_adapter:
                modify_ckpt_for_dec_adapter(checkpoint)
            # NOTE: non-strict loading, since adapter modules are absent from the ckpt
            # identify_unmatched_keys(self.state_dict(), checkpoint['model'])  # DEBUG
            self.load_state_dict(checkpoint['model'], strict=False)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(self.vocab_size,
                                              self.bert.model.config.hidden_size,
                                              padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        """Forward pass.

        Args:
            src (tensor(batch, max_src_len_batch)): Source token ids.
            tgt (tensor(batch, max_tgt_len_batch)): Target token ids.
            segs (tensor(batch, max_src_len_batch)): Segment ids (0 or 1) that
                separate source sentences.
            clss (tensor(batch, max_cls_num_batch)): Positions of [CLS] tokens.
            mask_src (tensor(batch, max_src_len_batch)): Mask (0 or 1) for
                source padding tokens.
            mask_tgt (tensor(batch, max_tgt_len_batch)): Mask (0 or 1) for
                target padding tokens.
            mask_cls (tensor(batch, max_cls_num_batch)): Mask (0 or 1) for
                [CLS] positions.

        Returns:
            A tuple of variables:
            decoder_outputs (tensor(batch, max_tgt_len_batch, dec_hidden_dim)):
                Hidden states from the decoder.
            top_vec (tensor(batch, max_src_len_batch, enc_hidden_dim)):
                Hidden states from the encoder.
        """
        # top_vec -> tensor(batch, max_src_len_batch, enc_hidden_dim)
        top_vec = self.bert(src, segs, mask_src)
        # dec_state -> models.decoder.TransformerDecoderState
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        # decoder_outputs -> tensor(batch, max_tgt_len_batch, dec_hidden_dim)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, top_vec

    # [For the inner loop]
    def _cascade_fast_weights_grad(self, fast_weights):
        """Sets fast-weight mode for adapter and layer-norm modules."""
        offset = 0
        for name, sub_module in self.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                param_num = len(sub_module._parameters)
                setattr(sub_module, 'fast_weights_flag', True)
                delattr(sub_module, 'fast_weights')
                setattr(sub_module, 'fast_weights',
                        fast_weights[offset:offset + param_num])
                offset += param_num
        return offset

    # [For the outer loop]
    def _clean_fast_weights_mode(self):
        """Cleans fast-weight mode for adapter and layer-norm modules."""
        module_num = 0
        for name, sub_module in self.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                setattr(sub_module, 'fast_weights_flag', False)
                delattr(sub_module, 'fast_weights')
                setattr(sub_module, 'fast_weights', None)
                module_num += 1
        return module_num

    def _adapter_fast_weights(self):
        """Yields fast (task) weights from the full model."""
        for name, sub_module in self.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_fast_weights_bert(self):
        """Yields fast (task) weights from the encoder."""
        for name, sub_module in self.bert.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_fast_weights_dec(self):
        """Yields fast (task) weights from the decoder."""
        for name, sub_module in self.decoder.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_vars(self):
        """Yields true (meta) parameters from the full model."""
        for name, sub_module in self.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param

    def _adapter_vars_bert(self):
        """Yields true (meta) parameters from the encoder."""
        for name, sub_module in self.bert.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param

    def _adapter_vars_dec(self):
        """Yields true (meta) parameters from the decoder."""
        for name, sub_module in self.decoder.named_modules():
            if isinstance(sub_module, (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items()
                      if n.startswith('bert.model')]), strict=True)

        if args.encoder == 'baseline':
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers,
                                     num_attention_heads=8,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos,
                                             self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight
        self.topical_output = None  # get_topical_output(15, self.args.dec_hidden_size, device)
        # self.topical_output[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(self.vocab_size,
                                              self.bert.model.config.hidden_size,
                                              padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight
                # self.topical_output[0].weight = self.decoder.embeddings.weight

        self.to(device)

    # def lda_process(self, batch):
    #     result = np.zeros((len(batch), 512))
    #     for i, b in enumerate(batch):
    #         src_txt = tokenizer.convert_ids_to_tokens(b.tolist())
    #         src_txt = preprocess(' '.join(src_txt))
    #         bow_vector = tm_dictionary.doc2bow(preprocess(' '.join(src_txt)))
    #         article_topic = sorted(lda_model[bow_vector], key=lambda tup: -1 * tup[1])  # [0]
    #         for index, value in article_topic[:1]:
    #             result[i, index] = value
    #     return result

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        top_vec = self.bert(src, segs, mask_src)
        # Optionally add small normally distributed noise:
        # noise = torch.normal(torch.zeros(top_vec.shape), torch.ones(top_vec.shape) / 2)
        # noise = noise.cuda()
        # top_vec += noise
        # if self.args.use_topic_modelling:
        #     lda_res = self.lda_process(src)
        #     for i1 in range(len(lda_res)):
        #         lda_res_tensor = torch.FloatTensor(lda_res[i1])
        #         for i2 in range(len(top_vec[i1])):
        #             try:
        #                 top_vec[i1, i2] += lda_res_tensor.cuda()
        #             except IndexError as err:
        #                 print(err)
        #                 print(top_vec.shape, lda_res.shape)
        #                 raise err
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items()
                      if n.startswith('bert.model')]), strict=True)

        if args.encoder == 'baseline':
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers,
                                     num_attention_heads=8,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos,
                                             self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size

        # Widen the encoder output for optional dependency / frame features,
        # then project back down to the decoder hidden size.
        self.enc_out_size = self.args.dec_hidden_size
        if self.args.use_dep:
            self.enc_out_size += 2
        if self.args.use_frame:
            self.enc_frame = nn.Linear(1, 20)
            self.frame_attn = MultiHeadedAttention(1, 20, 0.1)
            self.enc_out_size += 20
        self.enc_out = nn.Linear(self.enc_out_size, self.args.dec_hidden_size)
        self.drop = nn.Dropout(self.args.enc_dropout)
        self.layer_norm = nn.LayerNorm(self.args.dec_hidden_size, eps=1e-6)

        tgt_embeddings = nn.Embedding(self.vocab_size, self.args.dec_hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers, self.args.dec_hidden_size,
            heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(self.vocab_size,
                                              self.bert.model.config.hidden_size,
                                              padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, frame, dep):
        top_vec = self.bert(src, segs, mask_src)
        if self.args.use_dep:
            # dep_embeddings = self.enc_dep(dep[:, :, 0].unsqueeze(2).float())
            top_vec = torch.cat((top_vec, dep.float()), dim=2)
        if self.args.use_frame:
            frame_embeddings = self.enc_frame(frame.float().unsqueeze(-1))
            frame_embeddings = self.frame_attn(frame_embeddings, frame_embeddings,
                                               frame_embeddings, type="self")
            top_vec = torch.cat((top_vec, frame_embeddings), dim=2)
        top_vec = self.enc_out(top_vec)
        top_vec = self.layer_norm(top_vec)
        # top_vec = self.drop(top_vec)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
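# The forward above widens the encoder states with auxiliary per-token
# features and projects back to the decoder size. A minimal standalone
# sketch of the same concat-then-project pattern (toy sizes and
# hypothetical names, not the repo's modules):
import torch
import torch.nn as nn

hidden, extra = 768, 2                     # e.g. two dependency features per token
proj = nn.Linear(hidden + extra, hidden)
norm = nn.LayerNorm(hidden, eps=1e-6)
top_vec = torch.randn(4, 128, hidden)      # (batch, src_len, hidden)
dep = torch.randn(4, 128, extra)           # (batch, src_len, extra)
out = norm(proj(torch.cat((top_vec, dep), dim=2)))  # back to (4, 128, 768)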
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, args.cased, args.finetune_bert)

        if args.encoder == 'baseline':
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers,
                                     num_attention_heads=12,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos,
                                             self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(self.vocab_size,
                                              self.bert.model.config.hidden_size,
                                              padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, likes=None):
        top_vec = self.bert(src, segs, mask_src)
        if self.args.include_like_dist and self.args.mode == "train":
            # Scale each position by its sqrt-normalized like count.
            likes = torch.sqrt(likes.float())
            max_likes = torch.max(likes, dim=1).values.float()[:, None]
            norm_likes = (likes / max_likes)[:, :, None]
            top_vec = top_vec * norm_likes
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(
            tgt[:, :-1], top_vec, dec_state)  # <-- pass the graph vector here
        return decoder_outputs, None
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert, args.bart)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items()
                      if n.startswith('bert.model')]), strict=True)

        if args.encoder == 'baseline':
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers,
                                     num_attention_heads=8,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos,
                                             self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        # Set up the multi-task (monolingual) decoder
        if self.args.multi_task:
            self.decoder_monolingual = TransformerDecoder(
                self.args.dec_layers, self.args.dec_hidden_size,
                heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
                dropout=self.args.dec_dropout, embeddings=tgt_embeddings,
                sep_dec=self.args.sep_decoder)

        # if not args.bart:
        self.decoder = TransformerDecoder(
            self.args.dec_layers, self.args.dec_hidden_size,
            heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout, embeddings=tgt_embeddings,
            sep_dec=self.args.sep_decoder)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        # Initialize first, then load the checkpoint, to avoid loading errors.
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        if self.args.multi_task:
            for module in self.decoder_monolingual.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()
        if args.use_bert_emb:
            tgt_embeddings = nn.Embedding(self.vocab_size,
                                          self.bert.model.config.hidden_size,
                                          padding_idx=0)
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)
            self.decoder.embeddings = tgt_embeddings
            self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            if args.few_shot and args.multi_task:
                # Use one trained decoder to initialize both decoders
                new_states = OrderedDict()
                for each in checkpoint['model']:
                    if each.startswith('decoder'):
                        new_states[each] = copy.deepcopy(checkpoint['model'][each])
                        new_states[each.replace('decoder', 'decoder_monolingual')] = \
                            copy.deepcopy(checkpoint['model'][each])
                    else:
                        new_states[each] = copy.deepcopy(checkpoint['model'][each])
                self.load_state_dict(new_states, strict=True)
            else:
                self.load_state_dict(checkpoint['model'], strict=True)

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls,
                tgt_segs=None, tgt_eng=None):
        top_vec = self.bert(src, segs, mask_src)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        if self.args.multi_task:
            tgt_eng_segs = torch.ones(tgt_eng.size()).long().cuda()
            mono_dec_state = self.decoder_monolingual.init_decoder_state(src, top_vec)
            mono_decoder_outputs, mono_state = self.decoder_monolingual(
                tgt_eng[:, :-1], top_vec, mono_dec_state,
                tgt_segs=tgt_eng_segs[:, :-1])
        else:
            mono_decoder_outputs = None
            mono_state = None
        if tgt_segs is None:
            decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        else:
            decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state,
                                                  tgt_segs=tgt_segs[:, :-1])
        return decoder_outputs, None, mono_decoder_outputs
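# A toy demonstration of the few-shot duplication above: every "decoder.*"
# entry is cloned under the "decoder_monolingual.*" prefix so one trained
# decoder initializes both branches:
from collections import OrderedDict
import copy

ckpt = {"model": OrderedDict([("decoder.layer.weight", 1.0), ("bert.x", 2.0)])}
new_states = OrderedDict()
for each in ckpt["model"]:
    new_states[each] = copy.deepcopy(ckpt["model"][each])
    if each.startswith("decoder"):
        new_states[each.replace("decoder", "decoder_monolingual")] = \
            copy.deepcopy(ckpt["model"][each])
assert "decoder_monolingual.layer.weight" in new_states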
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.bert_model_path, args.large, args.temp_dir,
                         args.finetune_bert)

        max_pos = args.max_pos
        if max_pos > 512:
            my_pos_embeddings = nn.Embedding(max_pos,
                                             self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        # guide tags
        self.tag_embeddings = TiedEmbedding(args.max_n_tags,
                                            self.bert.model.config.hidden_size,
                                            padding_idx=0)
        self.tag_drop = nn.Dropout(args.tag_dropout)

        # decoder
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.bert.model.config.hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings,
                                          tag_embeddings=self.tag_embeddings)

        # generator
        self.generator = get_generator(args, self.vocab_size,
                                       self.bert.model.config.hidden_size,
                                       gen_weight=self.decoder.embeddings.weight)

        # load the checkpoint or initialize the parameters
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            self.tag_embeddings.weight.data.normal_(mean=0.0, std=0.02)
            self.tag_embeddings.weight[self.tag_embeddings.padding_idx].data.fill_(0)
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                self.decoder.embeddings.weight.data.copy_(
                    self.bert.model.embeddings.word_embeddings.weight)

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls,
                tag_src, tag_tgt):
        segs_src = (1 - segs % 2) * mask_src.long()
        top_vec = self.bert(src, segs_src, mask_src)

        # Sentence dropout: randomly zero out whole untagged source sentences.
        if self.training and self.args.sent_dropout > 0:
            idx = (torch.arange(clss.size(1), device=clss.device) + 1).unsqueeze(0).expand_as(clss)  # n x sents
            drop = torch.rand(clss.size(), dtype=torch.float,
                              device=clss.device) < self.args.sent_dropout  # n x sents
            idx = idx * drop.long()
            msk_drop = torch.sum((segs.unsqueeze(-2) == idx.unsqueeze(-1)).float(), dim=1)  # n x 512
            msk_tag = (torch.sum(tag_src, dim=2) > 0).float()  # n x 512
            msk_drop = msk_drop * (1 - msk_tag) * mask_src.float()
            top_vec = top_vec * (1 - msk_drop).unsqueeze(-1)

        tag_vec = self.tag_embeddings.matmul(tag_src)
        top_vec = top_vec + self.tag_drop(tag_vec)
        dec_state = self.decoder.init_decoder_state(src, top_vec)

        # Word dropout: replace random target tokens with BERT's [MASK] id.
        if self.training and self.args.word_dropout > 0:
            word_mask = 103
            drop = torch.rand(tgt.size(), dtype=torch.float,
                              device=tgt.device) < self.args.word_dropout
            drop = drop * mask_tgt
            tgt = torch.where(drop, tgt.new_full(tgt.size(), word_mask), tgt)

        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state,
                                              tag=tag_tgt[:, :-1])
        return decoder_outputs, None
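# Word dropout above swaps random target tokens for BERT's [MASK] id
# (103 in the standard BERT vocabularies). A standalone sketch:
import torch

tgt = torch.tensor([[101, 2023, 2003, 102]])
mask_tgt = tgt.ne(0)                              # True for real (non-pad) tokens
drop = (torch.rand(tgt.shape) < 0.3) & mask_tgt   # drop roughly 30% of real tokens
tgt = torch.where(drop, torch.full_like(tgt, 103), tgt)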
def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
    super(AbsSummarizer, self).__init__()
    self.args = args
    self.device = device
    # args.large=False, args.temp_dir='../temp', args.finetune_bert=True.
    # The input holds at most 512 tokens (minus [CLS] and [SEP]), merging at
    # most two sentences into one sequence; tokens and sentences beyond that
    # have no corresponding embedding. The pooler encodes the [CLS] position.
    self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

    if bert_from_extractive is not None:
        self.bert.model.load_state_dict(
            dict([(n[11:], p) for n, p in bert_from_extractive.items()
                  if n.startswith('bert.model')]), strict=True)

    if args.encoder == 'baseline':  # default: bert
        bert_config = BertConfig(self.bert.model.config.vocab_size,
                                 hidden_size=args.enc_hidden_size,
                                 num_hidden_layers=args.enc_layers,
                                 num_attention_heads=8,
                                 intermediate_size=args.enc_ff_size,
                                 hidden_dropout_prob=args.enc_dropout,
                                 attention_probs_dropout_prob=args.enc_dropout)
        self.bert.model = BertModel(bert_config)

    if args.max_pos > 512:  # max_pos never exceeds 512 here, so this branch is unused
        my_pos_embeddings = nn.Embedding(args.max_pos,
                                         self.bert.model.config.hidden_size)
        my_pos_embeddings.weight.data[:512] = \
            self.bert.model.embeddings.position_embeddings.weight.data
        my_pos_embeddings.weight.data[512:] = \
            self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
        self.bert.model.embeddings.position_embeddings = my_pos_embeddings

    self.vocab_size = self.bert.model.config.vocab_size  # vocab_size from bert.model's config: 21128
    tgt_embeddings = nn.Embedding(self.vocab_size,
                                  self.bert.model.config.hidden_size,  # hidden_size as above: 768
                                  padding_idx=0)  # embeds the summary tokens
    if self.args.share_emb:  # False
        tgt_embeddings.weight = copy.deepcopy(
            self.bert.model.embeddings.word_embeddings.weight)

    # BertModel serves as the feature extractor (the encoder here);
    # a Transformer serves as the decoder.
    self.decoder = TransformerDecoder(  # multi-head attention decoder
        self.args.dec_layers, self.args.dec_hidden_size,
        heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
        dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

    self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
    self.generator[0].weight = self.decoder.embeddings.weight  # 21128

    if checkpoint is not None:
        self.load_state_dict(checkpoint['model'], strict=True)
    else:  # train from scratch: initialize the parameters
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()
        if args.use_bert_emb:
            tgt_embeddings = nn.Embedding(self.vocab_size,
                                          self.bert.model.config.hidden_size,
                                          padding_idx=0)
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)
            self.decoder.embeddings = tgt_embeddings
            self.generator[0].weight = self.decoder.embeddings.weight

    self.to(device)
class RankAE(nn.Module): def __init__(self, args, device, vocab, checkpoint=None): super(RankAE, self).__init__() self.args = args self.device = device self.vocab = vocab self.vocab_size = len(vocab) self.beam_size = args.beam_size self.max_length = args.max_length self.min_length = args.min_length self.start_token = vocab['[unused1]'] self.end_token = vocab['[unused2]'] self.pad_token = vocab['[PAD]'] self.mask_token = vocab['[MASK]'] self.seg_token = vocab['[unused3]'] self.cls_token = vocab['[CLS]'] self.hidden_size = args.enc_hidden_size self.embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0) if args.encoder == 'bert': self.encoder = Bert(args.bert_dir, args.finetune_bert) if(args.max_pos > 512): my_pos_embeddings = nn.Embedding(args.max_pos, self.encoder.model.config.hidden_size) my_pos_embeddings.weight.data[:512] = self.encoder.model.embeddings.position_embeddings.weight.data my_pos_embeddings.weight.data[512:] = self.encoder.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos-512, 1) self.encoder.model.embeddings.position_embeddings = my_pos_embeddings tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0) else: self.encoder = TransformerEncoder(self.hidden_size, args.enc_ff_size, args.enc_heads, args.enc_dropout, args.enc_layers) tgt_embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0) self.hier_encoder = TransformerEncoder(self.hidden_size, args.hier_ff_size, args.hier_heads, args.hier_dropout, args.hier_layers) self.cup_bilinear = nn.Bilinear(self.hidden_size, self.hidden_size, 1) self.pos_emb = PositionalEncoding(0., self.hidden_size) self.decoder = TransformerDecoder( self.args.dec_layers, self.args.dec_hidden_size, heads=self.args.dec_heads, d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings) self.generator = Generator(self.vocab_size, self.args.dec_hidden_size, self.pad_token) self.generator.linear.weight = self.decoder.embeddings.weight if checkpoint is not None: self.load_state_dict(checkpoint['model'], strict=True) else: if args.encoder == "transformer": for module in self.encoder.modules(): self._set_parameter_tf(module) xavier_uniform_(self.embeddings.weight) for module in self.decoder.modules(): self._set_parameter_tf(module) for module in self.hier_encoder.modules(): self._set_parameter_tf(module) for p in self.generator.parameters(): self._set_parameter_linear(p) for p in self.cup_bilinear.parameters(): self._set_parameter_linear(p) if args.share_emb: if args.encoder == 'bert': self.embeddings = self.encoder.model.embeddings.word_embeddings tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0) tgt_embeddings.weight = copy.deepcopy(self.encoder.model.embeddings.word_embeddings.weight) else: tgt_embeddings = self.embeddings self.decoder.embeddings = tgt_embeddings self.generator.linear.weight = self.decoder.embeddings.weight self.to(device) def _set_parameter_tf(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): module.weight.data.normal_(mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() def _set_parameter_linear(self, p): if p.dim() > 1: xavier_uniform_(p) else: p.data.zero_() def _rebuild_tgt(self, origin, index, sep_token=None): tgt_list = [torch.tensor([self.start_token], device=self.device)] selected 
        selected = origin.index_select(0, index)
        for sent in selected:
            filted_sent = sent[sent != self.pad_token][1:]
            if sep_token is not None:
                filted_sent[-1] = sep_token
            else:
                filted_sent = filted_sent[:-1]
            tgt_list.append(filted_sent)
        new_tgt = torch.cat(tgt_list, 0)
        if sep_token is not None:
            new_tgt[-1] = self.end_token
        else:
            new_tgt = torch.cat([new_tgt, torch.tensor([self.end_token], device=self.device)], 0)
        return new_tgt

    def _build_memory_window(self, ex_segs, keep_clss, replace_clss=None, mask=None, samples=None):
        keep_cls_list = torch.split(keep_clss, ex_segs)
        window_list = []
        for ex in keep_cls_list:
            ex_pad = F.pad(ex, (0, 0, self.args.win_size, self.args.win_size)).unsqueeze(1)
            ex_context = torch.cat([ex_pad[:ex.size(0)], ex.unsqueeze(1), ex_pad[self.args.win_size * 2:]], 1)
            window_list.append(ex_context)
        memory = torch.cat(window_list, 0)
        if replace_clss is not None:
            replace_cls_list = torch.split(replace_clss, ex_segs)
            window_list = []
            for ex in replace_cls_list:
                ex_pad = F.pad(ex, (0, 0, self.args.win_size, self.args.win_size)).unsqueeze(1)
                ex_context = torch.cat([ex_pad[:ex.size(0)], ex.unsqueeze(1), ex_pad[self.args.win_size * 2:]], 1)
                window_list.append(ex_context)
            origin_memory = torch.cat(window_list, 0)
            sample_list = torch.split(samples, ex_segs)
            sample_tensor_list = []
            for i in range(len(ex_segs)):
                sample_index_ = torch.randint(0, samples.size(-1), [mask.size(-1)], device=self.device)
                sample_index = torch.index_select(sample_list[i], 1, sample_index_)
                sample_tensor = replace_cls_list[i][sample_index]
                sample_tensor_list.append(sample_tensor)
            sample_memory = torch.cat(sample_tensor_list, 0)
            # mask == 2: keep the noised memory; mask == 1: keep the original;
            # mask == 0: replace with a sampled utterance.
            memory = memory * (mask == 2).unsqueeze(-1).float() + \
                sample_memory * (mask == 0).unsqueeze(-1).float() + \
                origin_memory * (mask == 1).unsqueeze(-1).float()
        return memory

    def _src_add_noise(self, sent, sampled_sent, expand_ratio=0.):
        role_emb = sent[1:2]
        filted_sent = sent[sent != self.pad_token][2:]
        # filted_sent = sent[sent != self.pad_token][1:]
        rand_size = sampled_sent.size(0)
        length = max(int(filted_sent.size(0) * (1 + expand_ratio)), filted_sent.size(0) + 1)
        while filted_sent.size(0) < length:
            target_length = length - filted_sent.size(0)
            rand_sent = sampled_sent[random.randint(0, rand_size - 1)]
            rand_sent = rand_sent[rand_sent != self.pad_token][2:]  # remove cls and role embedding
            # rand_sent = rand_sent[rand_sent != self.pad_token][1:]  # no role embedding
            start_point = random.randint(0, rand_sent.size(0) - 1)
            end_point = random.randint(start_point, rand_sent.size(0))
            rand_segment = rand_sent[start_point:min(end_point, start_point + 10, start_point + target_length)]
            insert_point = random.randint(0, filted_sent.size(0) - 1)
            filted_sent = torch.cat([filted_sent[:insert_point], rand_segment, filted_sent[insert_point:]], 0)
        # return filted_sent
        return torch.cat([role_emb, filted_sent], 0)

    def _build_noised_src(self, src, ex_segs, samples, expand_ratio=0.):
        src_list = torch.split(src, ex_segs)
        new_src_list = []
        sample_list = torch.split(samples, ex_segs)
        for i, ex in enumerate(src_list):
            for j, sent in enumerate(ex):
                sampled_sent = ex.index_select(0, sample_list[i][j])
                expanded_sent = self._src_add_noise(sent, sampled_sent, expand_ratio)
                new_src = torch.cat([torch.tensor([self.cls_token], device=self.device), expanded_sent], 0)
                new_src_list.append(new_src)
        new_src = pad_sequence(new_src_list, batch_first=True, padding_value=self.pad_token)
        new_mask = new_src.data.ne(self.pad_token)
        new_segs = torch.zeros_like(new_src)
        return new_src, new_mask, new_segs

    def _build_context_tgt(self, tgt, ex_segs, win_size=1,
                           modify=False, mask=None):
        tgt_list = torch.split(tgt, ex_segs)
        new_tgt_list = []
        if modify and mask is not None:
            # 1 means keeping the sentence.
            mask_list = torch.split(mask, ex_segs)
        for i in range(len(tgt_list)):
            sent_num = tgt_list[i].size(0)
            for j in range(sent_num):
                if modify:
                    low = j - win_size
                    up = j + win_size + 1
                    index = torch.arange(low, up, device=self.device)
                    index = index[mask_list[i][j] > 0]
                else:
                    low = max(0, j - win_size)
                    up = min(sent_num, j + win_size + 1)
                    index = torch.arange(low, up, device=self.device)
                new_tgt_list.append(self._rebuild_tgt(tgt_list[i], index, self.seg_token))
        new_tgt = pad_sequence(new_tgt_list, batch_first=True, padding_value=self.pad_token)
        return new_tgt

    def _build_doc_tgt(self, tgt, vec, ex_segs, win_size=1, max_k=6, sigma=1.0):
        vec_list = torch.split(vec, ex_segs)
        tgt_list = torch.split(tgt, ex_segs)
        new_tgt_list = []
        index_list = []
        shift_list = []
        accum_index = 0
        for idx in range(len(ex_segs)):
            ex_vec = vec_list[idx]
            sent_num = ex_segs[idx]
            ex_tgt = tgt_list[idx]
            tgt_length = ex_tgt[:, 1:].ne(self.pad_token).sum(dim=1).float()
            topk_ids = self._centrality_rank(ex_vec, sent_num, tgt_length, win_size, max_k, sigma)
            new_tgt_list.append(self._rebuild_tgt(ex_tgt, topk_ids, self.seg_token))
            shift_list.append(topk_ids)
            index_list.append(topk_ids + accum_index)
            accum_index += sent_num
        new_tgt = pad_sequence(new_tgt_list, batch_first=True, padding_value=self.pad_token)
        return new_tgt, index_list, shift_list

    def _centrality_rank(self, vec, sent_num, tgt_length, win_size, max_k, sigma, eta=0.5, min_length=5):
        assert vec.size(0) == sent_num
        sim = torch.sigmoid(self.cup_bilinear(vec.unsqueeze(1).expand(sent_num, sent_num, -1).contiguous(),
                                              vec.unsqueeze(0).expand(sent_num, sent_num, -1).contiguous())
                            ).squeeze().detach()
        # sim = torch.sigmoid(torch.mm(vec, vec.transpose(0, 1)))
        # sim = torch.cosine_similarity(
        #     vec.unsqueeze(1).expand(sent_num, sent_num, -1).contiguous().view(sent_num * sent_num, -1),
        #     vec.unsqueeze(0).expand(sent_num, sent_num, -1).contiguous().view(sent_num * sent_num, -1)
        # ).view(sent_num, sent_num).detach()

        # Calculate the similarity weight: a Gaussian over sentence distance,
        # so nearby sentences contribute more to centrality.
        k = min(max(sent_num // (win_size * 2 + 1), 1), max_k)
        var = sent_num / k * 1.
        x = torch.arange(sent_num, device=self.device, dtype=torch.float).unsqueeze(0).expand_as(sim)
        u = torch.arange(sent_num, device=self.device, dtype=torch.float).unsqueeze(1)
        weight = torch.exp(-(x - u) ** 2 / (2. * var ** 2)) * (1. - torch.eye(sent_num, device=self.device))
        # weight = 1. - torch.eye(sent_num, device=self.device)
        sim[tgt_length < min_length, :] = -1e20

        # Calculate centrality and select the top-k sentences greedily,
        # penalizing similarity to already-selected sentences.
        topk_ids = torch.empty(0, dtype=torch.long, device=self.device)
        mask = torch.zeros([sent_num, sent_num], dtype=torch.float, device=self.device)
        for _ in range(k):
            mean_score = torch.sum(sim * weight, dim=1) / max(sent_num - 1, 1)
            max_v, _ = torch.max(sim * weight * mask, dim=1)
            centrality = eta * mean_score - (1 - eta) * max_v
            _, top_id = torch.topk(centrality, 1, dim=0, sorted=False)
            topk_ids = torch.cat([topk_ids, top_id], 0)
            sim[topk_ids, :] = -1e20
            mask[:, topk_ids] = 1.
        topk_ids, _ = torch.sort(topk_ids)
        # Alternative: rank every sentence once by raw centrality.
        # centrality = torch.sum(sim * weight, dim=1)
        # _, topk_ids = torch.topk(centrality, k, dim=0, sorted=False)
        # topk_ids, _ = torch.sort(topk_ids)
        return topk_ids

    def _add_mask(self, src, mask_src):
        pm_index = torch.empty_like(mask_src).float().uniform_().le(self.args.mask_token_prob)
        ps_index = torch.empty_like(mask_src[:, 0]).float().uniform_().gt(self.args.select_sent_prob)
        pm_index[ps_index] = 0
        # Avoid masking [PAD]
        pm_index[(1 - mask_src).byte()] = 0
        # Avoid masking [CLS]
        pm_index[:, 0] = 0
        # Avoid masking [SEG]
        pm_index[src == self.seg_token] = 0
        src[pm_index] = self.mask_token
        return src

    def _build_cup(self, bsz, ex_segs, win_size=1, negative_num=2):
        cup = torch.split(torch.arange(0, bsz, dtype=torch.long, device=self.device), ex_segs)
        tgt = torch.split(torch.ones(bsz), ex_segs)
        cup_list = []
        cup_origin_list = []
        tgt_list = []
        negative_list = []
        for i in range(len(ex_segs)):
            sent_num = ex_segs[i]
            cup_low = cup[i][0].item()
            cup_up = cup[i][sent_num - 1].item()
            cup_index = cup[i].repeat(win_size * 2 * (negative_num + 1))
            tgt_index = tgt[i].repeat(win_size * 2 * (negative_num + 1))
            cup_origin_list.append(cup[i].repeat(win_size * 2 * (negative_num + 1)))
            tgt_index[sent_num * win_size * 2:] = 0
            for j in range(cup_index.size(0)):
                if tgt_index[j] == 1:
                    # Positive pair: pick a neighbour inside the window.
                    window_list = [t for t in range(max(cup_index[j] - win_size, cup_low),
                                                    min(cup_index[j] + win_size, cup_up) + 1)
                                   if t != cup_index[j]]
                    cup_temp = window_list[(j // sent_num) % len(window_list)]
                else:
                    # Negative pair: sample an utterance outside the window.
                    cand_list = [t for t in range(cup_low, max(cup_index[j] - win_size, cup_low))] + \
                        [t for t in range(min(cup_index[j] + win_size, cup_up), cup_up)]
                    cup_temp = cand_list[random.randint(0, len(cand_list) - 1)]
                cup_index[j] = cup_temp
            negative_list.append((cup_index[sent_num * win_size * 2:] - cup_low).
                                 view(negative_num * win_size * 2, -1).transpose(0, 1))
            cup_list.append(cup_index)
            tgt_list.append(tgt_index)
        tgt = torch.cat(tgt_list, dim=0).float().to(self.device)
        cup_origin = torch.cat(cup_origin_list, dim=0)
        cup = torch.cat(cup_list, dim=0)
        negative_sample = torch.cat(negative_list, dim=0)
        return cup, cup_origin, tgt[cup != -1], negative_sample

    def _build_option_window(self, bsz, ex_segs, win_size=1, keep_ratio=0.1, replace_ratio=0.2):
        assert keep_ratio + replace_ratio <= 1.
        noise_ratio = 1 - keep_ratio - replace_ratio
        window_size = 2 * win_size + 1
        index = torch.split(torch.arange(1, bsz + 1, dtype=torch.long, device=self.device), ex_segs)
        # 2 means noise addition, 1 means keep the memory, 0 means replacement.
        tgt = torch.zeros([bsz, window_size], device=self.device, dtype=torch.int)
        prob = torch.empty([bsz, window_size], device=self.device).uniform_()
        tgt.masked_fill_(prob.lt(noise_ratio), 2)
        tgt.masked_fill_(prob.ge(1 - keep_ratio), 1)
        tgt = torch.split(tgt, ex_segs)
        for i in range(len(ex_segs)):
            sent_num = ex_segs[i]
            index_pad = F.pad(index[i], (self.args.win_size, self.args.win_size))
            for j in range(sent_num):
                window = index_pad[j:j + window_size]
                # Avoid windows where all elements are 0.
                if torch.sum(tgt[i][j].byte() * (window > 0)) == 0:
                    tgt[i][j][win_size] = 2
                tgt[i][j][window == 0] = -1
        tgt = torch.cat(tgt, 0)
        return tgt

    def _fast_translate_batch(self, batch, memory_bank, max_length, init_tokens=None,
                              memory_mask=None, min_length=2, beam_size=3, ignore_mem_attn=False):
        batch_size = memory_bank.size(0)
        dec_states = self.decoder.init_decoder_state(batch.src, memory_bank, with_cache=True)

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=0)
        init_tokens = tile(init_tokens, beam_size, dim=0)
        memory_mask = tile(memory_mask, beam_size, dim=0)
        batch_offset = torch.arange(
            batch_size, dtype=torch.long, device=self.device)
        beam_offset = torch.arange(
            0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=self.device)
        alive_seq = torch.full(
            [batch_size * beam_size, 1], self.start_token,
            dtype=torch.long, device=self.device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                         device=self.device).repeat(batch_size))

        # Structures that hold finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812
        results = [[] for _ in range(batch_size)]  # noqa: F812

        for step in range(max_length):
            if step > 0:
                init_tokens = None
            # Decoder forward.
            decoder_input = alive_seq[:, -1].view(1, -1)
            decoder_input = decoder_input.transpose(0, 1)
            dec_out, dec_states, _ = self.decoder(decoder_input, memory_bank, dec_states, init_tokens,
                                                  step=step, memory_masks=memory_mask,
                                                  ignore_memory_attn=ignore_mem_attn)

            # Generator forward.
            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if cur_len > 3:
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        if len(words) <= 3:
                            continue
                        trigrams = [(words[k - 1], words[k], words[k + 1])
                                    for k in range(1, len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            log_probs[i] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.args.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids (integer division).
            topk_beam_index = topk_ids // vocab_size
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                topk_beam_index
                + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),
                 topk_ids.view(-1, 1)], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition is the top beam being finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((
                            topk_scores[i, j],
                            predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        _, pred = best_hyp[0]
                        results[b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            if memory_bank is not None:
                memory_bank = memory_bank.index_select(0, select_indices)
            if memory_mask is not None:
                memory_mask = memory_mask.index_select(0, select_indices)
            if init_tokens is not None:
                init_tokens = init_tokens.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        results = [t[0] for t in results]
        return results

    def forward(self, batch):
        src = batch.src
        tgt = batch.tgt
        segs = batch.segs
        mask_src = batch.mask_src
        ex_segs = batch.ex_segs

        if self.training:
            # Sample some dialogue utterances for the auto-encoder.
            ex_size = batch.src.size(0)
            ex_index = [i for i in range(ex_size)]
            random.shuffle(ex_index)
            ex_indexs = torch.tensor(ex_index, dtype=torch.long, device=self.device)
            ex_sample_indexs = ex_indexs[:max(int(ex_size * self.args.sample_ratio), 1)]

            # Get context utterance training samples and targets.
            cup_index, cup_original_index, cup_tgt, negative_samples = \
                self._build_cup(src.size(0), ex_segs, self.args.win_size, self.args.negative_sample_num)
            setattr(batch, 'cup_tgt', cup_tgt)

        option_mask = self._build_option_window(src.size(0), ex_segs,
                                                win_size=self.args.win_size,
                                                keep_ratio=self.args.ps if self.training else 1.,
                                                replace_ratio=self.args.pr if self.training else 0.)

        if self.training:
            # Build the noised src.
            noised_src, noised_src_mask, noised_src_segs = \
                self._build_noised_src(src, ex_segs, samples=negative_samples,
                                       expand_ratio=self.args.expand_ratio)

        # Build the context tgt.
        context_tgt = self._build_context_tgt(tgt, ex_segs, self.args.win_size,
                                              modify=self.training, mask=option_mask)
        setattr(batch, 'context_tgt', context_tgt)

        # DAE: randomly mask tokens.
        if self.training:
            src = self._add_mask(src.clone(), mask_src)
            noised_src = self._add_mask(noised_src, noised_src_mask)

        if self.args.encoder == "bert":
            top_vec = self.encoder(src, segs, mask_src)
        else:
            src_emb = self.embeddings(src)
            top_vec = self.encoder(src_emb, 1 - mask_src)
        clss = top_vec[:, 0, :]

        # Hierarchical encoder.
        cls_list = torch.split(clss, ex_segs)
        cls_input = nn.utils.rnn.pad_sequence(cls_list, batch_first=True, padding_value=0.)
        cls_mask_list = [mask_src.new_zeros([length]) for length in ex_segs]
        cls_mask = nn.utils.rnn.pad_sequence(cls_mask_list, batch_first=True, padding_value=1)

        hier = self.hier_encoder(cls_input, cls_mask)
        hier = hier.view(-1, hier.size(-1))[(1 - cls_mask.view(-1)).byte()]

        if self.training:
            # Calculate the cup score.
            cup_tensor = torch.index_select(clss, 0, cup_index)
            origin_tensor = torch.index_select(clss, 0, cup_original_index)
            cup_score = torch.sigmoid(self.cup_bilinear(origin_tensor, cup_tensor)).squeeze()
            # cup_score = torch.sigmoid(origin_tensor.unsqueeze(1).bmm(cup_tensor.unsqueeze(-1)).squeeze())

            # Encode the noised src.
            if self.args.encoder == "bert":
                noised_top_vec = self.encoder(noised_src, noised_src_segs, noised_src_mask)
            else:
                noised_src_emb = self.embeddings(noised_src)
                noised_top_vec = self.encoder(noised_src_emb, 1 - noised_src_mask)
            noised_clss = noised_top_vec[:, 0, :]
            noised_cls_mem = self._build_memory_window(ex_segs, noised_clss, clss,
                                                       option_mask, negative_samples)
            noised_cls_mem = self.pos_emb(noised_cls_mem)

            # Sample training examples.
            context_tgt_sample = torch.index_select(context_tgt, 0, ex_sample_indexs)
            noised_cls_mem_sample = torch.index_select(noised_cls_mem, 0, ex_sample_indexs)
            hier_sample = torch.index_select(hier, 0, ex_sample_indexs)
        else:
            cup_score = None

        if self.training:
            dec_state = self.decoder.init_decoder_state(noised_src, noised_cls_mem_sample)
            decode_context, _, _ = self.decoder(context_tgt_sample[:, :-1], noised_cls_mem_sample,
                                                dec_state, init_tokens=hier_sample)
            doc_data = None
            # For loss computation.
            if ex_sample_indexs is not None:
                batch.context_tgt = context_tgt_sample
        else:
            decode_context = None
            # Build the paragraph tgt based on centrality rank.
            doc_tgt, doc_index, _ = self._build_doc_tgt(tgt, clss, ex_segs,
                                                        self.args.win_size, self.args.ranking_max_k)
            centrality_segs = [len(iex) for iex in doc_index]
            centrality_index = [sum(centrality_segs[:i]) for i in range(len(centrality_segs) + 1)]
            doc_index = torch.cat(doc_index, 0)
            setattr(batch, 'doc_tgt', doc_tgt)
            doc_hier_sample = torch.index_select(hier, 0, doc_index)

            # Original cls memory.
            cls_mem = self._build_memory_window(ex_segs, clss)
            cls_mem = self.pos_emb(cls_mem)
            doc_cls_mem = torch.index_select(cls_mem, 0, doc_index)

            # Context-aware doc target.
            context_doc_tgt = torch.index_select(context_tgt, 0, doc_index)
            setattr(batch, 'context_doc_tgt', context_doc_tgt)
            setattr(batch, 'doc_segs', centrality_index)

            doc_context_long = self._fast_translate_batch(batch, doc_cls_mem, self.max_length,
                                                          init_tokens=doc_hier_sample, min_length=2,
                                                          beam_size=self.beam_size)
            doc_context_long = [torch.cat(doc_context_long[centrality_index[i]:centrality_index[i + 1]], 0)
                                for i in range(len(centrality_segs))]
            doc_data = doc_context_long

        return cup_score, decode_context, doc_data
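# --- Illustration (not part of the model) ---
# A minimal, self-contained sketch of the two beam-search heuristics used in
# RankAE._fast_translate_batch above: the GNMT-style length penalty
# ((5 + len) / 6) ** alpha, which rescales cumulative log-probabilities, and
# trigram blocking, which rejects a hypothesis whose newest trigram already
# occurred earlier in the sequence. The alpha value and the toy hypothesis
# below are illustrative assumptions, not values from the model config.


def length_penalty(step, alpha):
    # step is 0-based, so step + 1 is the current hypothesis length.
    return ((5.0 + (step + 1)) / 6.0) ** alpha


def has_repeated_trigram(words):
    # Mirrors the block_trigram check above: block when the newest
    # trigram already appeared earlier in the hypothesis.
    if len(words) <= 3:
        return False
    trigrams = [tuple(words[k - 1:k + 2]) for k in range(1, len(words) - 1)]
    return trigrams[-1] in trigrams[:-1]


# The penalty grows with hypothesis length; the toy sequence below is
# blocked because the trigram (5, 6, 7) repeats.
assert round(length_penalty(step=9, alpha=1.0), 2) == 2.5
assert has_repeated_trigram([5, 6, 7, 8, 5, 6, 7])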
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.model_path, args.large, args.temp_dir, args.finetune_bert)
        if bert_from_extractive is not None:
            # Load encoder weights from an extractive checkpoint, stripping the
            # 'bert.model.' prefix so the keys match the bare BertModel.
            self.bert.model.load_state_dict(
                dict(
                    [
                        (n[11:], p)
                        for n, p in bert_from_extractive.items()
                        if n.startswith("bert.model")
                    ]
                ),
                strict=True,
            )
        if args.encoder == "baseline":
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout,
            )
            self.bert.model = BertModel(bert_config)
        if args.max_pos > 512:
            # Extend the 512 pretrained position embeddings by repeating the last row.
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(
            self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
        )
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight
            )
        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
        )
        self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint["model"], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(
                    self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
                )
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight
                )
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight
        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        top_vec = self.bert(src, segs, mask_src)
        # Zero every encoder state after the first position of the first
        # example, so the decoder only sees that single vector for it.
        for i in range(1, top_vec.shape[1]):
            top_vec[0][i] = torch.zeros(top_vec.shape[2])
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
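# --- Illustration (not part of the model) ---
# A minimal sketch of the position-embedding extension applied when
# args.max_pos > 512 in the constructors above: the 512 pretrained rows are
# copied over and the last pretrained row is repeated for every position
# beyond 512. The hidden size and max_pos below are illustrative assumptions.
import torch
import torch.nn as nn


def extend_position_embeddings(old, max_pos):
    new = nn.Embedding(max_pos, old.weight.size(1))
    new.weight.data[:512] = old.weight.data
    new.weight.data[512:] = old.weight.data[-1][None, :].repeat(max_pos - 512, 1)
    return new


# Every position past 511 shares the final pretrained embedding row.
_pretrained = nn.Embedding(512, 768)
_extended = extend_position_embeddings(_pretrained, max_pos=1024)
assert torch.equal(_extended.weight.data[511], _extended.weight.data[1023])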
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(dict([
                (n[11:], p) for n, p in bert_from_extractive.items()
                if n.startswith('bert.model')
            ]), strict=True)
        if args.encoder == 'baseline':
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)
        self.graph_encoder = graph_encoder(args, self.bert.model.embeddings)
        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
        self.decoder = TransformerDecoder(
            self.args.dec_layers, self.args.dec_hidden_size,
            heads=self.args.dec_heads, d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout, embeddings=tgt_embeddings)
        # for name, param in self.decoder.named_parameters():
        #     if name == 'fix_top':
        #         xavier_uniform_(param)
        self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device, args.copy)
        self.generator.voc_gen[0].weight = self.decoder.embeddings.weight
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator.voc_gen[0].weight = self.decoder.embeddings.weight
        self.copy = args.copy
        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, batch=None):
        # Both the graph encoding and the document encoding feed the decoder
        # call below, so these two lines must stay active.
        gents, emask = self.graph_encoder(batch, self.bert.model.embeddings)
        top_vec = self.bert(src, segs, mask_src)
        ent_top_vec = None
        if self.copy:
            ent_top_vec = self.bert(batch.ent_src, batch.ent_seg_ids, batch.mask_ent_src)
        # sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        # sents_vec = sents_vec * mask_cls[:, :, None].float()
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state, src_context, graph_context = self.decoder(
            tgt[:, :-1], top_vec, dec_state, gents=gents, emask=emask)
        return decoder_outputs, None, src_context, graph_context, top_vec, ent_top_vec, emask
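# --- Illustration (not part of the model) ---
# A minimal sketch of the checkpoint key remapping performed when
# bert_from_extractive is passed to the constructors above: only keys under
# 'bert.model' are kept, and the 11-character prefix 'bert.model.' is
# stripped so the remaining names match the bare BertModel state dict.
# The toy state dict is an illustrative assumption.
_extractive_state = {
    "bert.model.embeddings.word_embeddings.weight": "W",
    "ext_layer.wo.weight": "E",
}
_remapped = dict(
    [(n[11:], p) for n, p in _extractive_state.items() if n.startswith("bert.model")]
)
assert _remapped == {"embeddings.word_embeddings.weight": "W"}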