Example #1
def test_lm_forward(self):
    # Build a small config plus random input_ids, then run a forward pass with lm_labels.
    config, input_ids, batch_size = self._get_config_and_data(output_past=False)
    decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
    lm_model = BartForConditionalGeneration(config)
    lm_model.to(torch_device)
    loss, logits, enc_features = lm_model.forward(
        input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
    )
    # The logits must cover every decoder position and the full vocabulary,
    # and the loss must reduce to a scalar.
    expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
    self.assertEqual(logits.shape, expected_shape)
    self.assertIsInstance(loss.item(), float)
Example #2
def test_lm_uneven_forward(self):
    # Encoder input (context) and decoder input (summary) have different lengths;
    # the logits must still follow the decoder input's shape.
    config = BartConfig(
        vocab_size=self.vocab_size,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=48,
    )
    lm_model = BartForConditionalGeneration(config).to(torch_device)
    context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
    summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
    loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
    expected_shape = (*summary.shape, config.vocab_size)
    self.assertEqual(logits.shape, expected_shape)
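
The two tests above only exercise the training-style forward pass. As a point of reference, the same model class can also decode with generate; the sketch below is not part of the original tests, reuses the tiny sizes from Example #2, and fixes the vocabulary at 99 as an arbitrary stand-in for self.vocab_size, so a randomly initialised model will of course produce meaningless token ids.

import torch
from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=99, d_model=24, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32, max_position_embeddings=48,
)
lm_model = BartForConditionalGeneration(config)

context = torch.tensor([[71, 82, 18, 33, 46, 91, 2]])
# Beam-search decoding; with an untrained model this is purely an API illustration.
generated = lm_model.generate(context, num_beams=2, max_length=10, early_stopping=True)
print(generated.shape)  # (1, sequence length <= 10)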
Example #3
def train(
    config: TrainConfig,
    model: BartForConditionalGeneration,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    optimizer: Adam,
    logger: logging.Logger,
    device: torch.device,
):
    """ 지정된 Epoch만큼 모델을 학습시키는 함수입니다. """
    model.to(device)
    global_step = 0
    for epoch in range(1, config.num_epochs + 1):
        model.train()
        loss_sum = 0.0
        for data in train_dataloader:
            global_step += 1
            data = _change_device(data, device)  # helper not shown in this excerpt; a sketch follows this example
            optimizer.zero_grad()
            output = model.forward(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss.backward()
            loss_sum += loss.item()

            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if global_step % config.train_log_interval == 0:
                mean_loss = loss_sum / config.train_log_interval
                logger.info(
                    f"Epoch {epoch} Step {global_step} " f"Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}"
                )
                loss_sum = 0.0
            if global_step % config.dev_log_interval == 0:
                _validate(model, dev_dataloader, logger, device)
            if global_step % config.save_interval == 0:
                model.save_pretrained(f"{config.save_model_file_prefix}_{global_step}")
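
train (and _validate below) call a _change_device helper that the excerpt does not include. Judging from how data is indexed in the forward call, it only needs to move each tensor in the batch onto the target device while preserving the order (input_ids, attention_mask, decoder_input_ids, labels, decoder_attention_mask). A minimal sketch under that assumption:

import torch


def _change_device(data, device: torch.device):
    # Assumed behaviour: move every tensor in the batch (tuple or list) to `device`,
    # keeping the order so data[0]..data[4] still map to the same fields.
    return tuple(tensor.to(device) for tensor in data)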
Example #4
def _validate(
    model: BartForConditionalGeneration,
    dev_dataloader: DataLoader,
    logger: logging.Logger,
    device: torch.device,
):
    """Computes the mean loss and perplexity on the dev set, then restores training mode."""
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for data in tqdm(dev_dataloader):
            data = _change_device(data, device)
            output = model.forward(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss_sum += loss.item()
    mean_loss = loss_sum / len(dev_dataloader)
    logger.info(f"[Validation] Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}")
    model.train()  # hand the model back to the training loop in training mode
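
For orientation only, a hypothetical end-to-end driver is sketched below, assuming the train, _validate, and _change_device definitions above are in scope. The original excerpt never shows TrainConfig, the tokenised dataset, or the real hyperparameters, so every concrete value here (the dataclass stub, the dummy tensors, learning rate, batch size) is an assumption chosen to make the sketch self-contained; only the field names and the 5-tuple batch layout are taken from how train and _validate use them.

import logging
from dataclasses import dataclass

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartConfig, BartForConditionalGeneration


@dataclass
class TrainConfig:
    # Stand-in with only the fields train() actually reads; the real class is not shown.
    num_epochs: int
    train_log_interval: int
    dev_log_interval: int
    save_interval: int
    save_model_file_prefix: str


def dummy_loader(n=16, src_len=7, tgt_len=5, vocab=99, batch_size=4):
    # Random token ids standing in for a tokenised (context, summary) dataset.
    # In real fine-tuning, decoder_input_ids would be the labels shifted right.
    src = torch.randint(3, vocab, (n, src_len))
    tgt = torch.randint(3, vocab, (n, tgt_len))
    dataset = TensorDataset(src, torch.ones_like(src), tgt, tgt.clone(), torch.ones_like(tgt))
    return DataLoader(dataset, batch_size=batch_size)


logging.basicConfig(level=logging.INFO)
model = BartForConditionalGeneration(
    BartConfig(vocab_size=99, d_model=24, encoder_layers=2, decoder_layers=2,
               encoder_attention_heads=2, decoder_attention_heads=2,
               encoder_ffn_dim=32, decoder_ffn_dim=32, max_position_embeddings=48)
)
train(
    config=TrainConfig(num_epochs=1, train_log_interval=2, dev_log_interval=4,
                       save_interval=4, save_model_file_prefix="checkpoints/bart_demo"),
    model=model,
    train_dataloader=dummy_loader(),
    dev_dataloader=dummy_loader(),
    optimizer=Adam(model.parameters(), lr=5e-5),
    logger=logging.getLogger("bart-train"),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)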