Example #1
0
 def test_generate(self):
     """Smoke-test greedy generation on a tiny BART LM: output has shape (batch, 20)."""
     # Two short input sequences; trailing token 2 is presumably EOS -- TODO confirm.
     source_tokens = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
     # Deliberately tiny model so the test runs fast.
     tiny_config = BartConfig(
         vocab_size=self.vocab_size,
         d_model=24,
         encoder_layers=2,
         decoder_layers=2,
         encoder_attention_heads=2,
         decoder_attention_heads=2,
         encoder_ffn_dim=32,
         decoder_ffn_dim=32,
         max_position_embeddings=48,
         output_past=True,
     )
     model = BartForMaskedLM(tiny_config)
     generated = model.generate(source_tokens)
     # With default arguments, each generated sequence is 20 tokens long.
     self.assertEqual(generated.shape, (source_tokens.shape[0], 20))
    def test_generate_beam_search(self):
        """Beam-search generation on a tiny BART LM returns one sequence of max_length per input."""
        prompt = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
        # Deliberately tiny model so the test runs fast.
        model_config = BartConfig(
            vocab_size=self.vocab_size,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            output_past=True,
        )
        model = BartForMaskedLM(model_config)
        model.eval()  # inference mode (disables dropout etc.)

        generated = model.generate(
            prompt.clone(),
            num_return_sequences=1,
            num_beams=2,
            no_repeat_ngram_size=3,
            max_length=5,
        )
        # One returned sequence per input, truncated/padded to max_length=5.
        self.assertEqual(generated.shape, (prompt.shape[0], 5))
Example #3
0
 def test_generate_fp16(self):
     """Generation should run without error on a half-precision BART model.

     Bug fix: ``attention_mask`` was previously passed positionally, so it
     bound to ``generate``'s second positional parameter (``max_length``)
     instead of ``attention_mask``; the mask was never applied. Pass it by
     keyword so it reaches the intended parameter.
     """
     config, input_ids, batch_size = self._get_config_and_data(
         output_past=True)
     # Token id 1 is presumably the pad id; mask out padded positions -- TODO confirm.
     attention_mask = input_ids.ne(1)
     lm_model = BartForMaskedLM(config).eval().to(torch_device).half()
     lm_model.generate(input_ids, attention_mask=attention_mask)