def test_lm_forward(self):
    config, input_ids, batch_size = self._get_config_and_data(output_past=False)
    decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
    lm_model = BartForConditionalGeneration(config)
    lm_model.to(torch_device)
    loss, logits, enc_features = lm_model(
        input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
    )
    expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
    self.assertEqual(logits.shape, expected_shape)
    self.assertIsInstance(loss.item(), float)
def test_lm_uneven_forward(self):
    config = BartConfig(
        vocab_size=self.vocab_size,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=48,
    )
    lm_model = BartForConditionalGeneration(config).to(torch_device)
    context = torch.tensor(
        [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=torch.long, device=torch_device
    )
    summary = torch.tensor(
        [[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=torch.long, device=torch_device
    )
    loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
    expected_shape = (*summary.shape, config.vocab_size)
    self.assertEqual(logits.shape, expected_shape)
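# Note: the two tests above use the legacy transformers (<4.0) call signature,
# where the model takes lm_labels= and returns a plain tuple. On transformers
# >= 4.0 the argument is named labels= and the default return type is a
# Seq2SeqLMOutput, as in the training code below. A minimal sketch of the
# equivalent call under that newer API (assumption: transformers >= 4.0):
#
#     output = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
#     loss, logits = output.loss, output.logits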
import logging
import math

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BartForConditionalGeneration


# TrainConfig and _change_device are project-local helpers not shown in this
# excerpt; sketches of both follow below.
def train(
    config: TrainConfig,
    model: BartForConditionalGeneration,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    optimizer: Adam,
    logger: logging.Logger,
    device: torch.device,
):
    """Trains the model for the configured number of epochs."""
    model.to(device)
    global_step = 0
    for epoch in range(1, config.num_epochs + 1):
        model.train()
        loss_sum = 0.0
        for data in train_dataloader:
            global_step += 1
            data = _change_device(data, device)
            optimizer.zero_grad()
            output = model(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss.backward()
            loss_sum += loss.item()
            # Clip gradients before the optimizer step to stabilize training.
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if global_step % config.train_log_interval == 0:
                mean_loss = loss_sum / config.train_log_interval
                logger.info(
                    f"Epoch {epoch} Step {global_step} "
                    f"Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}"
                )
                loss_sum = 0.0
            if global_step % config.dev_log_interval == 0:
                _validate(model, dev_dataloader, logger, device)
            if global_step % config.save_interval == 0:
                model.save_pretrained(f"{config.save_model_file_prefix}_{global_step}")
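# A minimal sketch of the _change_device helper referenced by train() and
# _validate(). Its definition is not part of this excerpt; this assumes each
# batch is a sequence of tensors (input_ids, attention_mask, decoder_input_ids,
# labels, decoder_attention_mask), matching how data[0]..data[4] are indexed.
def _change_device(data, device: torch.device):
    # Move every tensor in the batch onto the target device.
    return tuple(tensor.to(device) for tensor in data)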
def _validate(
    model: BartForConditionalGeneration,
    dev_dataloader: DataLoader,
    logger: logging.Logger,
    device: torch.device,
):
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for data in tqdm(dev_dataloader):
            data = _change_device(data, device)
            output = model(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss_sum += loss.item()
    mean_loss = loss_sum / len(dev_dataloader)
    logger.info(f"[Validation] Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}")
    model.train()
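# A minimal end-to-end usage sketch for train(). Everything here is
# illustrative, not the author's setup: TrainConfig is assumed to be a simple
# dataclass carrying exactly the fields read inside train(), the checkpoint
# name is a placeholder, and the dataset is random dummy data so the sketch
# runs on its own.
from dataclasses import dataclass

from torch.utils.data import TensorDataset


@dataclass
class TrainConfig:
    num_epochs: int = 1
    train_log_interval: int = 10
    dev_log_interval: int = 50
    save_interval: int = 50
    save_model_file_prefix: str = "checkpoints/bart"


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = Adam(model.parameters(), lr=3e-5)

    # Random 5-tuple batches matching the indexing in train()/_validate().
    n, seq_len, vocab = 8, 16, model.config.vocab_size
    dataset = TensorDataset(
        torch.randint(0, vocab, (n, seq_len)),     # input_ids
        torch.ones(n, seq_len, dtype=torch.long),  # attention_mask
        torch.randint(0, vocab, (n, seq_len)),     # decoder_input_ids
        torch.randint(0, vocab, (n, seq_len)),     # labels
        torch.ones(n, seq_len, dtype=torch.long),  # decoder_attention_mask
    )
    train_loader = DataLoader(dataset, batch_size=4)
    dev_loader = DataLoader(dataset, batch_size=4)

    train(TrainConfig(), model, train_loader, dev_loader, optimizer, logger, device)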