def test_generate(self):
     # Smoke test: calling generate() with no extra arguments on a tiny,
     # randomly initialised BartForMaskedLM must return one row per input row.
     tiny_config_kwargs = dict(
          vocab_size=self.vocab_size,
          d_model=24,
          encoder_layers=2,
          decoder_layers=2,
          encoder_attention_heads=2,
          decoder_attention_heads=2,
          encoder_ffn_dim=32,
          decoder_ffn_dim=32,
          max_position_embeddings=48,
          output_past=True,
     )
     config = BartConfig(**tiny_config_kwargs)

     lm_model = BartForMaskedLM(config)
     # Evaluation mode, so the model is not in training mode while generating.
     lm_model.eval()

     # Two prompts of three token ids each.
     input_ids = torch.tensor([[71, 82, 2], [68, 34, 2]], dtype=torch.long)
     generated_ids = lm_model.generate(input_ids)

     # NOTE(review): 20 is presumably generate()'s default max_length -- confirm.
     expected_shape = (input_ids.size(0), 20)
     self.assertEqual(generated_ids.shape, expected_shape)
    def test_generate_beam_search(self):
        # Smoke test: generate() with num_beams=2 on a tiny, randomly
        # initialised BartForMaskedLM must return a (batch, max_length) tensor.
        tiny_config_kwargs = dict(
            vocab_size=self.vocab_size,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            output_past=True,
        )
        config = BartConfig(**tiny_config_kwargs)

        lm_model = BartForMaskedLM(config)
        # Evaluation mode, so the model is not in training mode while generating.
        lm_model.eval()

        # Two prompts of three token ids each.
        input_ids = torch.tensor([[71, 82, 2], [68, 34, 2]], dtype=torch.long)

        # A copy is passed in, so generate() never receives the tensor that
        # the shape assertion below reads.
        generated_ids = lm_model.generate(
            input_ids.clone(),
            max_length=5,
            num_beams=2,
            num_return_sequences=1,
            no_repeat_ngram_size=3,
        )

        # The second dimension is expected to equal the max_length passed above.
        expected_shape = (input_ids.size(0), 5)
        self.assertEqual(generated_ids.shape, expected_shape)