def get_optimizer_and_schedule(args, model: EncoderDecoderModel):
    # Use different learning rates for the pretrained parameters and the
    # newly initialized (cross-attention) parameters.
    if args.ngpu > 1:
        model = model.module
    init_params_id = []
    for layer in model.decoder.transformer.h:
        init_params_id.extend(list(map(id, layer.crossattention.parameters())))
        init_params_id.extend(list(map(id, layer.ln_cross_attn.parameters())))
    pretrained_params = filter(
        lambda p: id(p) not in init_params_id, model.parameters()
    )
    initialized_params = filter(lambda p: id(p) in init_params_id, model.parameters())
    # Newly initialized params fall back to the default lr; pretrained params
    # use the (usually smaller) finetune_lr.
    params_setting = [
        {"params": initialized_params},
        {"params": pretrained_params, "lr": args.finetune_lr},
    ]
    optimizer = optim.Adam(params_setting, lr=args.lr)
    schedule = get_cosine_schedule_with_warmup(
        optimizer,
        num_training_steps=args.num_training_steps,
        num_warmup_steps=args.num_warmup_steps,
    )
    return optimizer, schedule
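A minimal usage sketch for the helper above, assuming `args` exposes `ngpu`, `lr`, `finetune_lr`, `num_training_steps`, and `num_warmup_steps`, and that the decoder is a GPT-2-style model (so `decoder.transformer.h` and its cross-attention submodules exist). The checkpoint names and the loop body are illustrative, not part of the original code.

import argparse

from transformers import EncoderDecoderModel

# Hypothetical args namespace; only the fields read by get_optimizer_and_schedule are set.
args = argparse.Namespace(ngpu=1, lr=5e-4, finetune_lr=5e-5,
                          num_training_steps=10000, num_warmup_steps=500)

# BERT encoder + GPT-2 decoder so that decoder.transformer.h exists (illustrative checkpoints).
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")
optimizer, schedule = get_optimizer_and_schedule(args, model)

for step in range(args.num_training_steps):
    # ... forward pass and loss.backward() go here ...
    optimizer.step()
    schedule.step()
    optimizer.zero_grad()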
def create_and_check_encoder_decoder_shared_weights(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, labels,
        **kwargs):
    torch.manual_seed(0)
    encoder_model, decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    model.to(torch_device)
    model.eval()

    # loading the state dict copies the weights but does not tie them
    decoder_state_dict = model.decoder._modules[
        model.decoder.base_model_prefix].state_dict()
    model.encoder.load_state_dict(decoder_state_dict, strict=False)

    torch.manual_seed(0)
    tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        tied_encoder_model.config,
        tied_decoder_model.config,
        tie_encoder_decoder=True)
    tied_model = EncoderDecoderModel(encoder=tied_encoder_model,
                                     decoder=tied_decoder_model,
                                     config=config)
    tied_model.to(torch_device)
    tied_model.eval()

    model_result = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    tied_model_result = tied_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )

    # check that the tied model has fewer parameters
    self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                    sum(p.numel() for p in model.parameters()))
    random_slice_idx = ids_tensor((1, ), model_result[0].shape[-1]).item()

    # check that the outputs are equal
    self.assertTrue(
        torch.allclose(model_result[0][0, :, random_slice_idx],
                       tied_model_result[0][0, :, random_slice_idx],
                       atol=1e-4))

    # check that the outputs after saving and loading are equal
    with tempfile.TemporaryDirectory() as tmpdirname:
        tied_model.save_pretrained(tmpdirname)
        tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
        tied_model.to(torch_device)
        tied_model.eval()

        # check that the tied model has fewer parameters
        self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                        sum(p.numel() for p in model.parameters()))
        random_slice_idx = ids_tensor((1, ), model_result[0].shape[-1]).item()

        tied_model_result = tied_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        # check that the outputs are equal
        self.assertTrue(
            torch.allclose(model_result[0][0, :, random_slice_idx],
                           tied_model_result[0][0, :, random_slice_idx],
                           atol=1e-4))
class BERT2BERTTrainer(pl.LightningModule):
    def __init__(self, lr, **args):
        super(BERT2BERTTrainer, self).__init__()
        self.save_hyperparameters()
        encoder = BertGenerationEncoder.from_pretrained(
            "ckiplab/bert-base-chinese",
            bos_token_id=101,
            eos_token_id=102,
            # force_download=True
        )
        decoder = BertGenerationDecoder.from_pretrained(
            "ckiplab/bert-base-chinese",
            add_cross_attention=True,
            is_decoder=True,
            bos_token_id=101,
            eos_token_id=102)
        self.bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        if args['with_keywords_loss']:
            self.loss_fct2 = KeywordsLoss(alpha=args['keywords_loss_alpha'],
                                          loss_fct=args['keywords_loss_fct'])

    def generate(self, inputs_ids, attention_mask=None, **kwargs):
        inputs_ids = inputs_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            return self.bert2bert.generate(input_ids=inputs_ids,
                                           attention_mask=attention_mask,
                                           bos_token_id=101,
                                           min_length=100,
                                           eos_token_id=102,
                                           pad_token_id=0,
                                           **kwargs).detach().cpu().tolist()

    def forward(self, inputs):
        with torch.no_grad():
            return self.bert2bert(**inputs)

    def training_step(self, inputs, batch_idx):
        title, body = inputs
        title['input_ids'] = title['input_ids'].squeeze(1)
        title['attention_mask'] = title['attention_mask'].squeeze(1)
        body['input_ids'] = body['input_ids'].squeeze(1)
        body['attention_mask'] = body['attention_mask'].squeeze(1)
        ret = self.bert2bert(input_ids=title['input_ids'],
                             attention_mask=title['attention_mask'],
                             decoder_input_ids=body['input_ids'],
                             decoder_attention_mask=body['attention_mask'],
                             labels=body['input_ids'])
        loss2 = self.loss_fct2(
            ret.logits,
            title['input_ids']) if self.hparams['with_keywords_loss'] else 0.
        self.log('keyword_loss', loss2, prog_bar=True)
        self.log('clm_loss', ret.loss, prog_bar=True)
        return {'loss': ret.loss + loss2, 'keyword_loss': loss2}

    def training_epoch_end(self, outputs):
        mean_loss = torch.stack([x['loss'] for x in outputs]).reshape(-1).mean()
        self.log('mean_loss', mean_loss)

    def configure_optimizers(self):
        opt = optim.AdamW(self.bert2bert.parameters(), lr=self.hparams['lr'])
        return opt

    @staticmethod
    def add_parser_args(parser):
        # parser.add_argument('--lr', type=float)
        parser.add_argument('--with_keywords_loss', action='store_true')
        parser.add_argument('--keywords_loss_alpha',
                            type=float,
                            default=0.7,
                            help='float > 0.5')
        parser.add_argument('--keywords_loss_fct',
                            type=str,
                            default='kldiv',
                            help='kldiv or mse')
        return parser
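A short sketch of how this LightningModule could be launched, assuming `KeywordsLoss` and a DataLoader named `train_loader` (yielding `(title, body)` tokenizer-output pairs) are defined elsewhere; the Trainer settings are illustrative and follow the older Lightning API implied by `training_epoch_end`.

import argparse

import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=5e-5)
parser = BERT2BERTTrainer.add_parser_args(parser)
cli_args = parser.parse_args()

module = BERT2BERTTrainer(**vars(cli_args))
trainer = pl.Trainer(gpus=1, max_epochs=3)  # illustrative settings
trainer.fit(module, train_loader)  # train_loader is assumed to be defined elsewhere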
    pad_token_id=decoder_tokenizer.vocab["[PAD]"],
    cls_token_id=decoder_tokenizer.vocab["[CLS]"],
    mask_token_id=decoder_tokenizer.vocab["[MASK]"],
    bos_token_id=decoder_tokenizer.vocab["[BOS]"],
    eos_token_id=decoder_tokenizer.vocab["[EOS]"],
)

# Initialize a brand new bert-based decoder.
decoder = BertGenerationDecoder(config=decoder_config)

# Set up the encoder-decoder model.
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
bert2bert.config.decoder_start_token_id = decoder_tokenizer.vocab["[CLS]"]
bert2bert.config.pad_token_id = decoder_tokenizer.vocab["[PAD]"]

# Elementary training loop.
optimizer = torch.optim.Adam(bert2bert.parameters(), lr=0.000001)
bert2bert.cuda()

for epoch in range(30):
    print("*" * 50, "Epoch", epoch, "*" * 50)
    if True:
        for batch in tqdm(sierra_dl):
            # Tokenize commands and goals.
            inputs = encoder_tokenizer(batch["command"],
                                       add_special_tokens=True,
                                       return_tensors="pt",
                                       padding=True,
                                       truncation=True)
            labels = decoder_tokenizer(
                batch["symbolic_plan_processed"],
                return_tensors="pt",
                padding=True,
                max_length=sierra_ds.max_plan_length,
                truncation=True,
                add_special_tokens=True,
            )

            # Move to GPU.
            for key, item in inputs.items():
                if type(item).__name__ == "Tensor":
                    inputs[key] = item.cuda()
            for key, item in labels.items():
def train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='/content/BERT checkpoints'):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)
    model = model.to(device)
    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("/content/train_data.pth")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)
    val_data = torch.load("/content/validate_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=True,
                                batch_size=batch_size)
    # ------------------------END LOAD TRAIN DATA--------------

    # ------------------------SET OPTIMIZER-------------------
    num_train_optimization_steps = len(
        train_dataset) * epochs // batch_size // num_gradients_accumulation

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=lr,
        weight_decay=0.01,
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_train_optimization_steps // 10,
        num_training_steps=num_train_optimization_steps)

    # ------------------------START TRAINING-------------------
    update_count = 0
    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        # ------------------------training------------------------
        model.train()
        losses = 0
        times = 0
        print('\n' + '-' * 20 + f'epoch {epoch}' + '-' * 20)
        for batch in tqdm(train_dataloader):
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch

            logits = model(input_ids=encoder_input,
                           attention_mask=mask_encoder_input,
                           decoder_input_ids=decoder_input,
                           decoder_attention_mask=mask_decoder_input)

            # shift logits/targets for next-token prediction
            out = logits[0][:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()
            loss = util.sequence_cross_entropy_with_logits(out,
                                                           target,
                                                           target_mask,
                                                           average="token")
            loss.backward()

            losses += loss.item()
            times += 1

            update_count += 1
            if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
                # max_grad_norm is assumed to be defined at module level
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        end = time.time()
        print(f'time: {(end - start)}')
        print(f'loss: {losses / times}')
        start = end

        # ------------------------validate------------------------
        model.eval()
        perplexity = 0
        batch_count = 0
        print('\nstart calculating the perplexity....')

        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                batch = [item.to(device) for item in batch]
                encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch

                logits = model(input_ids=encoder_input,
                               attention_mask=mask_encoder_input,
                               decoder_input_ids=decoder_input,
                               decoder_attention_mask=mask_decoder_input)

                out = logits[0][:, :-1].contiguous()
                target = decoder_input[:, 1:].contiguous()
                target_mask = mask_decoder_input[:, 1:].contiguous()
                # print(out.shape, target.shape, target_mask.shape)
                loss = util.sequence_cross_entropy_with_logits(out,
                                                               target,
                                                               target_mask,
                                                               average="token")
                perplexity += np.exp(loss.item())
                batch_count += 1

        print(f'\nvalidate perplexity: {perplexity / batch_count}')

        torch.save(
            model.state_dict(),
            os.path.join(os.path.abspath('.'), load_dir,
                         "model-" + str(epoch) + ".pth"))
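A sketch of how train_model could be invoked; max_grad_norm is assumed to be a module-level constant (the function references it but does not define it), and the /content paths used above must already contain the preprocessed tensors and the checkpoint directory.

max_grad_norm = 1.0  # assumed module-level constant read by the gradient-clipping call

if __name__ == '__main__':
    train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='/content/BERT checkpoints')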
decoder = BertForMaskedLM(config=decoder_config)

# Define encoder decoder model
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.to(device)


def count_parameters(mdl):
    return sum(p.numel() for p in mdl.parameters() if p.requires_grad)


print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=modelparams['lr'])
criterion = nn.NLLLoss(ignore_index=de_tokenizer.pad_token_id)

num_train_batches = len(train_dataloader)
num_valid_batches = len(valid_dataloader)


def compute_loss(predictions, targets):
    """Compute our custom loss"""
    predictions = predictions[:, :-1, :].contiguous()
    targets = targets[:, 1:]

    rearranged_output = predictions.view(
        predictions.shape[0] * predictions.shape[1], -1)
    rearranged_target = targets.contiguous().view(-1)
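The snippet above stops before compute_loss returns anything. Below is a hedged completion and usage sketch under a hypothetical name, so the original fragment stays untouched; the log_softmax step and the batch unpacking are assumptions (nn.NLLLoss expects log-probabilities, not raw logits).

import torch.nn.functional as F


def compute_loss_completed(predictions, targets):
    """Hypothetical completion: same shift/reshape as compute_loss above,
    followed by the NLLLoss call on log-probabilities."""
    predictions = predictions[:, :-1, :].contiguous()
    targets = targets[:, 1:]
    rearranged_output = predictions.view(
        predictions.shape[0] * predictions.shape[1], -1)
    rearranged_target = targets.contiguous().view(-1)
    return criterion(F.log_softmax(rearranged_output, dim=-1), rearranged_target)


# Illustrative training step, assuming batches of (src_ids, src_mask, tgt_ids):
# outputs = model(input_ids=src_ids, attention_mask=src_mask, decoder_input_ids=tgt_ids)
# loss = compute_loss_completed(outputs.logits, tgt_ids)
# loss.backward(); optimizer.step(); optimizer.zero_grad()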