# Example 1
    def test_generation_from_short_input_same_as_parlai_3B(self):
        """Greedy-ish generation from a one-word prompt should reproduce the known ParlAI 3B reply."""
        # Load tokenizer and Flax model converted from the PyTorch checkpoint.
        tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")
        model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B", from_pt=True)

        inputs = tokenizer(["Sam"], return_tensors="jax")

        # num_beams=1 with tight length bounds keeps generation fast.
        generated_ids = model.generate(
            **inputs, num_beams=1, early_stopping=True, min_length=15, max_length=25
        )
        decoded = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        expected = 'Sam is a great name. It means "sun" in Gaelic.'
        assert decoded[0].strip() == expected
# Example 2
 def test_lm_uneven_forward(self):
     """A forward pass with encoder/decoder inputs of different lengths yields decoder-shaped logits."""
     # Tiny config so the model builds quickly in the test.
     cfg = BlenderbotConfig(
         vocab_size=self.vocab_size,
         d_model=14,
         encoder_layers=2,
         decoder_layers=2,
         encoder_attention_heads=2,
         decoder_attention_heads=2,
         encoder_ffn_dim=8,
         decoder_ffn_dim=8,
         max_position_embeddings=48,
     )
     model = FlaxBlenderbotForConditionalGeneration(cfg)
     # Encoder input is length 7, decoder input is length 5 — deliberately uneven.
     encoder_ids = np.array([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=np.int64)
     decoder_ids = np.array([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=np.int64)
     result = model(input_ids=encoder_ids, decoder_input_ids=decoder_ids)
     # Logits must follow the decoder's (batch, seq) shape plus the vocab dimension.
     self.assertEqual(result["logits"].shape, (*decoder_ids.shape, cfg.vocab_size))
 def test_lm_forward(self):
     """Plain forward pass: logits come back shaped (batch, seq_len, vocab_size)."""
     config, input_ids, batch_size = self._get_config_and_data()
     model = FlaxBlenderbotForConditionalGeneration(config)
     result = model(input_ids=input_ids)
     self.assertEqual(result["logits"].shape, (batch_size, input_ids.shape[1], config.vocab_size))