# Example No. 1
    def test_generate_beam_search(self):
        """Beam generation on a tiny BART model returns one fixed-length row per input.

        A small BART model is built directly from a config (not loaded from a
        checkpoint) and ``generate`` is run with ``num_beams=2`` and
        ``num_return_sequences=1``. The output must have shape
        ``(batch_size, max_length)``.

        Note: ``do_sample=True`` is passed, so this exercises beam *sampling*
        rather than purely deterministic beam search.
        """
        # Two short source rows, each ending with the EOS id (2) configured below.
        source_ids = torch.tensor([[71, 82, 2], [68, 34, 2]], dtype=torch.long).to(torch_device)
        batch_size = source_ids.shape[0]

        tiny_config = BartConfig(
            vocab_size=self.vocab_size,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            eos_token_id=2,
            pad_token_id=1,
            bos_token_id=0,
        )
        model = BartForConditionalGeneration(tiny_config).to(torch_device)
        model.eval()

        target_len = 5
        # Pass a copy so generate() cannot alter the tensor used for the shape check.
        generated = model.generate(
            source_ids.clone(),
            max_length=target_len,
            num_beams=2,
            num_return_sequences=1,
            no_repeat_ngram_size=3,
            do_sample=True,
        )
        self.assertEqual(generated.shape, (batch_size, target_len))
# Example No. 2
def _validate(
    model: BartForConditionalGeneration,
    dev_dataloader: DataLoader,
    logger: logging.Logger,
    device: torch.device,
) -> float:
    """Compute and log the mean validation loss and perplexity of ``model``.

    The model is switched to eval mode and run under ``torch.no_grad()`` over
    every batch of ``dev_dataloader``. Afterwards it is switched back to train
    mode, including when an exception interrupts the loop.

    Args:
        model: Seq2seq model whose output (with ``return_dict=True``) exposes
            a ``"loss"`` entry for the supplied ``labels``.
        dev_dataloader: Yields batches indexed positionally as
            ``(input_ids, attention_mask, decoder_input_ids, labels,
            decoder_attention_mask)``.
        logger: Receives the ``[Validation]`` summary line.
        device: Target device; each batch is moved there via ``_change_device``.

    Returns:
        The unweighted mean of the per-batch losses. If the loader yielded no
        batches, returns ``nan`` and logs a warning.
    """
    model.eval()
    loss_sum = 0.0
    # Count batches as they arrive instead of trusting len(dev_dataloader),
    # which is unavailable or only an estimate for iterable-style loaders.
    num_batches = 0
    try:
        with torch.no_grad():
            for data in tqdm(dev_dataloader):
                data = _change_device(data, device)
                # Call the module itself (not .forward) so registered hooks run.
                output = model(
                    input_ids=data[0],
                    attention_mask=data[1],
                    decoder_input_ids=data[2],
                    labels=data[3],
                    decoder_attention_mask=data[4],
                    return_dict=True,
                )
                loss_sum += output["loss"].item()
                num_batches += 1
    finally:
        # Return to training mode even if validation fails part-way through.
        model.train()

    if num_batches == 0:
        # Previously a ZeroDivisionError; an empty dev set is reported instead.
        logger.warning("[Validation] dev_dataloader yielded no batches; skipping")
        return float("nan")

    # Every batch has equal weight, regardless of how many examples it holds.
    mean_loss = loss_sum / num_batches
    try:
        perplexity = math.exp(mean_loss)
    except OverflowError:
        # math.exp overflows for arguments above ~709 (e.g. a diverged model);
        # report an infinite perplexity instead of aborting training.
        perplexity = math.inf
    logger.info("[Validation] Loss %.4f Perplexity %8.2f", mean_loss, perplexity)
    return mean_loss