def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
    """Check PT -> Flax weight-porting equivalence for an encoder-decoder pair.

    Builds one PyTorch and one Flax encoder-decoder model from the same
    combined config, copies the PyTorch weights into the Flax model, then
    delegates the output comparison to ``check_pt_flax_equivalence``.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
    torch_model = EncoderDecoderModel(combined_config)
    flax_model = FlaxEncoderDecoderModel(combined_config)
    # Port the PyTorch weights so both models hold identical parameters.
    flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)
    self.check_pt_flax_equivalence(torch_model, flax_model, inputs_dict)
def train_model(epochs=10, num_gradients_accumulation=4, batch_size=4, gpu_id=0,
                lr=1e-5, load_dir='/content/BERT checkpoints', max_grad_norm=1.0):
    """Fine-tune a BERT-to-BERT encoder-decoder model with gradient accumulation.

    Loads pre-tokenized tensors from ``/content/*.pth``, trains for ``epochs``
    passes, reports validation perplexity after each epoch, and saves a
    checkpoint per epoch under ``load_dir``.

    Args:
        epochs: number of passes over the training set.
        num_gradients_accumulation: micro-batches accumulated per optimizer step.
        batch_size: batch size for both train and validation loaders.
        gpu_id: CUDA device index the model and batches are moved to.
        lr: AdamW learning rate.
        load_dir: directory where per-epoch checkpoints are written.
        max_grad_norm: gradient-clipping threshold. NOTE: this was previously
            read as an undefined global, which raised NameError on the first
            accumulated optimizer step; it is now a parameter with a default.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased', is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)
    model = model.to(device)
    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("/content/train_data.pth")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
    val_data = torch.load("/content/validate_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset, shuffle=True, batch_size=batch_size)
    # ------------------------END LOAD TRAIN DATA--------------

    # ------------------------SET OPTIMIZER-------------------
    # Total optimizer steps = samples * epochs / batch / accumulation.
    num_train_optimization_steps = (
        len(train_dataset) * epochs // batch_size // num_gradients_accumulation)

    # Standard BERT recipe: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    # Per-group weight_decay above takes precedence over the optimizer-level one.
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=lr,
        weight_decay=0.01,
    )
    # Linear warmup over the first 10% of steps, then linear decay to zero.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_train_optimization_steps // 10,
        num_training_steps=num_train_optimization_steps)

    # ------------------------START TRAINING-------------------
    update_count = 0
    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        # ------------------------training------------------------
        model.train()
        losses = 0
        times = 0
        print('\n' + '-' * 20 + f'epoch {epoch}' + '-' * 20)
        for batch in tqdm(train_dataloader):
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
            logits = model(input_ids=encoder_input,
                           attention_mask=mask_encoder_input,
                           decoder_input_ids=decoder_input,
                           decoder_attention_mask=mask_decoder_input)

            # Shift: predict token t+1 from position t.
            out = logits[0][:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()

            loss = util.sequence_cross_entropy_with_logits(out, target, target_mask,
                                                           average="token")
            # Log the unscaled loss, but scale the backward pass so the
            # accumulated gradient equals the average over the micro-batches
            # (previously the effective gradient was multiplied by the
            # accumulation factor).
            losses += loss.item()
            (loss / num_gradients_accumulation).backward()
            times += 1

            update_count += 1
            if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        end = time.time()
        print(f'time: {(end - start)}')
        print(f'loss: {losses / times}')
        start = end

        # ------------------------validate------------------------
        model.eval()
        perplexity = 0
        batch_count = 0
        print('\nstart calculate the perplexity....')
        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                batch = [item.to(device) for item in batch]
                encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
                logits = model(input_ids=encoder_input,
                               attention_mask=mask_encoder_input,
                               decoder_input_ids=decoder_input,
                               decoder_attention_mask=mask_decoder_input)

                out = logits[0][:, :-1].contiguous()
                target = decoder_input[:, 1:].contiguous()
                target_mask = mask_decoder_input[:, 1:].contiguous()

                # Per-batch perplexity = exp(mean token cross-entropy).
                loss = util.sequence_cross_entropy_with_logits(out, target, target_mask,
                                                               average="token")
                perplexity += np.exp(loss.item())
                batch_count += 1

        print(f'\nvalidate perplexity: {perplexity / batch_count}')

        # NOTE: load_dir defaults to an absolute path, in which case
        # os.path.join ignores os.path.abspath('.') — kept for compatibility.
        torch.save(
            model.state_dict(),
            os.path.join(os.path.abspath('.'), load_dir, "model-" + str(epoch) + ".pth"))
# Tail of a training loop: the enclosing `for epoch ...` header and the
# definitions of `epoch`, `mean_train_loss`, `mean_train_acc`, `model`,
# `args`, `begin_time`, the tokenizers and `valid_seen_dataloader` are
# outside this fragment — indentation below is reconstructed, verify against
# the enclosing loop.
print("epoch: {} train_loss: {:.3f}, train_acc: {:.3f}".format(epoch, mean_train_loss, mean_train_acc))

# Every 5th epoch: greedy-decode the seen validation split and score
# exact-prefix matches against the tokenized target plans.
if epoch % 5 == 0:
    model.eval()
    valid_all_match = []
    for tasks, plans in tqdm(valid_seen_dataloader):
        try:
            tokenized_text = encoder_tokenizer(tasks, padding=True, truncation=True, max_length=100, return_tensors="pt").input_ids
            if args.gpu and torch.cuda.is_available():
                tokenized_text = tokenized_text.to("cuda")
            # decoder_start_token_id=1 — presumably the decoder's BOS id; confirm
            # against the decoder tokenizer's special-token map.
            output_labels = model.generate(tokenized_text, decoder_start_token_id=1)
            # NOTE(review): `ouput_array` is a typo for `output_array` (harmless,
            # but worth renaming). Also, `generate` yields token *ids* while
            # `tokenize` usually yields token *strings* — confirm targets[0]
            # really holds ids, otherwise this comparison is always False.
            ouput_array = output_labels.cpu().numpy()[0]
            targets = decoder_tokenizer.tokenize(plans)
            all_match = np.all(ouput_array[:len(targets[0])] == targets[0])
            valid_all_match.append(1 if all_match else 0)
        # NOTE(review): bare except counts any failure (including bugs and
        # KeyboardInterrupt) as a mismatch — narrow it and log the error.
        except:
            valid_all_match.append(0)
    print("epoch: {} valid_all_match: {:.3f}".format(epoch, np.mean(valid_all_match)))

end_time = datetime.now()
print("Finished Training:", end_time)
print("Time elapsed (in hour): {:.2f}".format((end_time - begin_time).total_seconds() / 3600))
# save the model
# NOTE(review): the datetime's str() contains spaces and colons, which is an
# invalid filename on some filesystems (e.g. Windows) — consider strftime.
torch.save(model.state_dict(), "exp/{}.pth".format(end_time))
config.save_pretrained("exp/{}".format(end_time))