def test_generate(self):
    input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
    config = BartConfig(
        vocab_size=self.vocab_size,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=48,
        output_past=True,
    )
    lm_model = BartForMaskedLM(config)
    new_input_ids = lm_model.generate(input_ids)
    # With no max_length given, generate() falls back to the default max_length (20).
    self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
def test_generate_beam_search(self):
    input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
    config = BartConfig(
        vocab_size=self.vocab_size,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=48,
        output_past=True,
    )
    lm_model = BartForMaskedLM(config)
    lm_model.eval()
    new_input_ids = lm_model.generate(
        input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=5
    )
    self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 5))
def test_generate_fp16(self):
    config, input_ids, batch_size = self._get_config_and_data(output_past=True)
    attention_mask = input_ids.ne(1)
    lm_model = BartForMaskedLM(config).eval().to(torch_device)
    if torch_device == "cuda":
        # Only exercise half precision on GPU; several ops lack CPU fp16 kernels.
        lm_model.half()
    # attention_mask must be passed by keyword: the second positional argument of
    # generate() is not the attention mask.
    lm_model.generate(input_ids, attention_mask=attention_mask)
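# `_get_config_and_data` lives elsewhere in the same test class and is not shown in this
# excerpt. Below is a minimal sketch of what such a helper could look like, assuming it
# simply mirrors the inline BartConfig construction used in test_generate and
# test_generate_beam_search and returns device-placed input ids ending in the EOS id 2;
# the body in the real test file may differ.
def _get_config_and_data(self, output_past=False):
    # Small batch of token ids on the test device (assumed, since the fp16 test moves
    # the model to torch_device and generates from these ids directly).
    input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
    batch_size = input_ids.shape[0]
    config = BartConfig(
        vocab_size=self.vocab_size,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=48,
        output_past=output_past,
    )
    return config, input_ids, batch_size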