def test_generate_beam_search(self):
    """Smoke-test beam-sample generation on a tiny BART model.

    Builds a 2-layer BART, generates with ``num_beams=2`` and sampling
    enabled, and checks the output tensor has one sequence per input row,
    each of length ``max_length``.
    """
    input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
    config = BartConfig(
        vocab_size=self.vocab_size,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=48,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
    )
    lm_model = BartForConditionalGeneration(config).to(torch_device)
    lm_model.eval()

    max_length = 5
    generated = lm_model.generate(
        input_ids.clone(),
        do_sample=True,
        num_return_sequences=1,
        num_beams=2,
        no_repeat_ngram_size=3,
        max_length=max_length,
    )
    # One returned sequence per input row, each capped at max_length tokens.
    self.assertEqual(generated.shape, (input_ids.shape[0], max_length))
def _validate(
    model: BartForConditionalGeneration,
    dev_dataloader: DataLoader,
    logger: logging.Logger,
    device: torch.device,
) -> None:
    """Run one evaluation pass over ``dev_dataloader`` and log mean loss.

    Puts ``model`` into eval mode for the pass, accumulates the per-batch
    loss under ``torch.no_grad()``, logs the mean loss and its perplexity
    (``exp(mean_loss)``), and restores train mode before returning.

    Each batch is indexed as (input_ids, attention_mask, decoder_input_ids,
    labels, decoder_attention_mask) — assumes the dataloader yields tuples
    in that order; TODO confirm against the dataset/collate function.

    Args:
        model: Model under evaluation; left in ``train()`` mode on return.
        dev_dataloader: Validation batches; must be non-empty (its length
            is the divisor for the mean).
        logger: Destination for the summary line.
        device: Device the batch tensors are moved to via ``_change_device``.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for data in tqdm(dev_dataloader):
            data = _change_device(data, device)
            # Call the module itself rather than model.forward(...) so that
            # registered forward hooks (if any) are honored.
            output = model(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss_sum += output["loss"].item()
    # Mean over batches, not samples — matches the original behavior.
    mean_loss = loss_sum / len(dev_dataloader)
    logger.info(f"[Validation] Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}")
    model.train()