from transformers import BartConfig, BartTokenizer


def main(args):
    # Load the tokenizer and derive a matching BartConfig from it.
    tokenizer = BartTokenizer.from_pretrained(args.tokenizer_path)
    bart_config = BartConfig()
    bart_config.vocab_size = len(tokenizer)
    bart_config.eos_token_id = tokenizer.eos_token_id
    bart_config.bos_token_id = tokenizer.bos_token_id
    bart_config.pad_token_id = tokenizer.pad_token_id
    # Write config.json to the target directory.
    bart_config.save_pretrained(args.config_path)
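
# Illustrative usage only: a minimal argparse entry point for main(), assuming the
# script is run standalone. The flag names --tokenizer_path and --config_path mirror
# the attributes accessed above; the exact CLI is an assumption, not part of the
# original code.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Derive a BartConfig from a tokenizer and save it to disk.")
    parser.add_argument("--tokenizer_path", type=str, required=True,
                        help="Directory or model name holding the BartTokenizer files.")
    parser.add_argument("--config_path", type=str, required=True,
                        help="Output directory for the generated config.json.")
    main(parser.parse_args())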
def build_graph(self):
    """Build the model and, in training mode, the optimizer and LR scheduler."""
    if self.config.bart_pre_training:
        self.model = CustomBartGeneration.from_pretrained(
            self.config.bart_pre_training)
        if self.model.config.vocab_size != self.config.vocab_size:
            # The vocabulary differs from the pretrained model's,
            # so resize the embedding table to match.
            self.model.resize_token_embeddings(self.config.vocab_size)
    else:
        # Build a BART model from scratch using the run configuration.
        bart_config = BartConfig()
        bart_config.activation_function = self.config.activate_func
        bart_config.vocab_size = self.config.vocab_size
        bart_config.d_model = self.config.embed_size
        bart_config.max_position_embeddings = self.config.embed_size
        bart_config.max_length = self.config.max_generate_length
        bart_config.num_labels = self.config.num_labels
        bart_config.image_para_freeze = self.config.image_para_freeze
        bart_config.encoder_layers = self.config.n_layers
        bart_config.decoder_layers = self.config.n_layers
        bart_config.encoder_attention_heads = self.config.n_head
        bart_config.decoder_attention_heads = self.config.n_head
        bart_config.encoder_ffn_dim = self.config.ffn_dim
        bart_config.decoder_ffn_dim = self.config.ffn_dim
        bart_config.pad_token_id = PAD_ID
        bart_config.bos_token_id = BOS_ID
        bart_config.eos_token_id = EOS_ID
        self.model = CustomBartGeneration(config=bart_config)

        # multi-task
        # bart_config.summary_use_proj = True
        # bart_config.summary_activation = None
        # bart_config.summary_first_dropout = 0.1
        # bart_config.summary_proj_to_labels = True
        # bart_config.summary_type = "cls_index"
        # self.model = CustomBartGenerationDoubleHeads(config=bart_config)

    if torch.cuda.is_available():
        self.model.to(self.config.device)

    if self.config.checkpoint:
        self.checkpoint_dict = self.load_model(self.config.checkpoint)

    if self.is_train:
        # Exclude biases and layer-norm weights from weight decay.
        no_decay = ['bias', 'layer_norm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': self.config.weight_decay
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.config.learning_rate,
                               eps=self.config.adam_epsilon)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.num_warmup_steps,
            num_training_steps=self.config.num_training_steps)

        if self.config.checkpoint and self.checkpoint_dict:
            # Restore optimizer state from the checkpoint.
            self.optimizer.load_state_dict(
                self.checkpoint_dict["optimizer"])
            # Restore the learning-rate scheduler state.
            self.scheduler.load_state_dict(
                self.checkpoint_dict["lr_scheduler"])
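
# Illustrative sketch (not part of the original code): how the objects built by
# build_graph are typically driven for one optimization step. The forward signature
# of CustomBartGeneration and the batch layout are assumptions; only self.model,
# self.optimizer, and self.scheduler come from the method above.
def _example_train_step(self, batch):
    outputs = self.model(**batch)      # assumed to return an object exposing .loss
    loss = outputs.loss
    loss.backward()                    # accumulate gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  # optional clipping
    self.optimizer.step()              # AdamW update over the two decay groups
    self.scheduler.step()              # advance the linear warmup/decay schedule
    self.optimizer.zero_grad()
    return loss.item()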