def prepare_config_and_inputs(self):
    """Build a config plus random model inputs for testing.

    Returns:
        Tuple of (config, input_ids, input_mask, sequence_labels,
        token_labels, choice_labels). Label tensors are ``None`` unless
        ``self.use_labels`` is set; the mask is ``None`` unless
        ``self.use_input_mask`` is set.
    """
    input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    input_mask = None
    if self.use_input_mask:
        input_mask = random_attention_mask([self.batch_size, self.seq_length])

    sequence_labels = None
    token_labels = None
    choice_labels = None
    if self.use_labels:
        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
        token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
        choice_labels = ids_tensor([self.batch_size], self.num_choices)

    # Reuse get_config() instead of duplicating the MPNetConfig kwargs here,
    # so the two stay in sync when hyperparameters change.
    config = self.get_config()

    return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def main(train_epoch, batch_size, seq_length, lr, corpus_path, vocab_path,
         config_path, pretrain_model_path, output_record_path, model_save_path):
    """Run MPNet masked-LM pretraining on a text corpus and save the model.

    Args:
        train_epoch: Number of training epochs.
        batch_size: Per-device train batch size.
        seq_length: Maximum sequence length used by the data collator.
        lr: Learning rate.
        corpus_path: Path to the raw pretraining corpus.
        vocab_path: Path/dir for the BertTokenizer vocabulary.
        config_path: Path/dir for the MPNetConfig.
        pretrain_model_path: Dir of an existing checkpoint to warm-start from.
        output_record_path: Trainer output dir (logs/records).
        model_save_path: Where the final model is saved.
    """
    # Fixed seed for reproducible data order / init.
    seed_everything(997)

    tokenizer = BertTokenizer.from_pretrained(vocab_path)
    data = read_data(corpus_path, tokenizer)
    train_dataset = OppoDataset(data)

    config = MPNetConfig.from_pretrained(
        pretrained_model_name_or_path=config_path)

    # NOTE(review): the existence check uses a hard-coded path while the
    # weights are loaded from `pretrain_model_path` — confirm both refer to
    # the same directory, otherwise the warm-start condition is wrong.
    if os.path.exists('./mpnet_model/pytorch_model.bin'):
        model = MPNetForMaskedLM.from_pretrained(pretrain_model_path, config=config)
    else:
        model = MPNetForMaskedLM(config=config)
    # Shrink/expand the token embedding table to the task vocabulary size.
    # NOTE(review): 883 is a magic number — presumably len(tokenizer); verify.
    model.resize_token_embeddings(883)

    data_collator = Collator(max_seq_len=seq_length, tokenizer=tokenizer,
                             mlm_probability=0.25)

    training_args = TrainingArguments(
        output_dir=output_record_path,
        overwrite_output_dir=True,
        num_train_epochs=train_epoch,
        learning_rate=lr,
        dataloader_num_workers=0,
        prediction_loss_only=True,
        per_device_train_batch_size=batch_size,
        save_strategy='no',
        seed=1080,
        label_smoothing_factor=0.001,
    )

    trainer = Trainer(model=model,
                      args=training_args,
                      data_collator=data_collator,
                      train_dataset=train_dataset)
    trainer.train()
    trainer.save_model(model_save_path)
def get_config(self):
    """Assemble an MPNetConfig from this tester's hyperparameters."""
    config_kwargs = dict(
        vocab_size=self.vocab_size,
        hidden_size=self.hidden_size,
        num_hidden_layers=self.num_hidden_layers,
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self.intermediate_size,
        hidden_act=self.hidden_act,
        hidden_dropout_prob=self.hidden_dropout_prob,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        initializer_range=self.initializer_range,
    )
    return MPNetConfig(**config_kwargs)
train=pd.read_csv('../tcdata/train.csv',header=None) # test=pd.read_csv('../tcdata/track1_round1_testB.csv',header=None) test=pd.read_csv('../tcdata/testB.csv',header=None) model_path='../model_weight/mpnet/' output_model='../tmp/MPNet.pth' batch_size=32 # 合并训练集与测试集 制作特征 for i in range(1,3): train[i]=train[i].apply(lambda x:x.replace('|','').strip()) for i in range(1,2): test[i]=test[i].apply(lambda x:x.replace('|','').strip()) train.columns=['idx','sentence','label1','label2'] test.columns=['idx','sentence'] tokenizer=BertTokenizerFast.from_pretrained(model_path) config=MPNetConfig.from_pretrained(model_path,num_labels=17,hidden_dropout_prob=0.2) # config.output_attentions=True # In[11]: def train_model(train_df,val_df,test_oof): ###-------------------- early_stop=0 print("Reading training data...") train_set = CustomDataset(train_df, maxlen=128,tokenizer=tokenizer) train_loader = Data.DataLoader(train_set, batch_size=batch_size, num_workers=5, shuffle=True) print("Reading validation data...") val_set = CustomDataset(val_df, maxlen=128, tokenizer=tokenizer)
def get_large_model_config(self):
    """Load the reference configuration of the released MPNet base checkpoint."""
    checkpoint_name = "microsoft/mpnet-base"
    return MPNetConfig.from_pretrained(checkpoint_name)