def train_model(epochs=10, num_gradients_accumulation=4, batch_size=4, gpu_id=0, lr=1e-5, load_dir='/content/BERT checkpoints'): # make sure your model is on GPU device = torch.device(f"cuda:{gpu_id}") # ------------------------LOAD MODEL----------------- print('load the model....') bert_encoder = BertConfig.from_pretrained('bert-base-uncased') bert_decoder = BertConfig.from_pretrained('bert-base-uncased', is_decoder=True) config = EncoderDecoderConfig.from_encoder_decoder_configs( bert_encoder, bert_decoder) model = EncoderDecoderModel(config) model = model.to(device) print('load success') # ------------------------END LOAD MODEL-------------- # ------------------------LOAD TRAIN DATA------------------ train_data = torch.load("/content/train_data.pth") train_dataset = TensorDataset(*train_data) train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size) val_data = torch.load("/content/validate_data.pth") val_dataset = TensorDataset(*val_data) val_dataloader = DataLoader(dataset=val_dataset, shuffle=True, batch_size=batch_size) # ------------------------END LOAD TRAIN DATA-------------- # ------------------------SET OPTIMIZER------------------- num_train_optimization_steps = len( train_dataset) * epochs // batch_size // num_gradients_accumulation param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] optimizer = AdamW( optimizer_grouped_parameters, lr=lr, weight_decay=0.01, ) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=num_train_optimization_steps // 10, num_training_steps=num_train_optimization_steps) # ------------------------START TRAINING------------------- update_count = 0 start = time.time() print('start training....') for epoch in range(epochs): # ------------------------training------------------------ model.train() losses = 0 times = 0 print('\n' + '-' * 20 + f'epoch {epoch}' + '-' * 20) for batch in tqdm(train_dataloader): batch = [item.to(device) for item in batch] encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch logits = model(input_ids=encoder_input, attention_mask=mask_encoder_input, decoder_input_ids=decoder_input, decoder_attention_mask=mask_decoder_input) out = logits[0][:, :-1].contiguous() target = decoder_input[:, 1:].contiguous() target_mask = mask_decoder_input[:, 1:].contiguous() loss = util.sequence_cross_entropy_with_logits(out, target, target_mask, average="token") loss.backward() losses += loss.item() times += 1 update_count += 1 if update_count % num_gradients_accumulation == num_gradients_accumulation - 1: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() scheduler.step() optimizer.zero_grad() end = time.time() print(f'time: {(end - start)}') print(f'loss: {losses / times}') start = end # ------------------------validate------------------------ model.eval() perplexity = 0 batch_count = 0 print('\nstart calculate the perplexity....') with torch.no_grad(): for batch in tqdm(val_dataloader): batch = [item.to(device) for item in batch] encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch logits = model(input_ids=encoder_input, attention_mask=mask_encoder_input, decoder_input_ids=decoder_input, decoder_attention_mask=mask_decoder_input) out = logits[0][:, :-1].contiguous() target = decoder_input[:, 1:].contiguous() target_mask = mask_decoder_input[:, 1:].contiguous() # print(out.shape,target.shape,target_mask.shape) loss = util.sequence_cross_entropy_with_logits(out, target, target_mask, average="token") perplexity += np.exp(loss.item()) batch_count += 1 print(f'\nvalidate perplexity: {perplexity / batch_count}') torch.save( model.state_dict(), os.path.join(os.path.abspath('.'), load_dir, "model-" + str(epoch) + ".pth"))
model.encoder = BertModel.from_pretrained('bert-base-uncased') if args.gpu and torch.cuda.is_available(): model = model.cuda() loss_fun = nn.CrossEntropyLoss() #loss_fun = nn.CrossEntropyLoss(ignore_index=0) optimizer = torch.optim.Adam(model.decoder.parameters(), lr=args.lr) # 记录时间 begin_time = datetime.now() print("Start training BERT: ", begin_time) # 开始训练 for epoch in range(args.epoch): model.train() train_loss_list = [] train_acc_list = [] for tasks, plans in tqdm(train_dataloader): tokenized_text = encoder_tokenizer(tasks, padding=True, truncation=True, max_length=100, return_tensors="pt").input_ids targets = torch.LongTensor(decoder_tokenizer.tokenize(plans)) if args.gpu and torch.cuda.is_available(): tokenized_text = tokenized_text.to("cuda") targets = targets.to("cuda") outputs = model(input_ids = tokenized_text, decoder_input_ids = targets[:,:-1]) output_dim = outputs.logits.shape[-1] outputs.logits = outputs.logits.contiguous().view(-1, output_dim) labels = targets[:,1:].contiguous().view(-1)