def test_generation_from_short_input_same_as_parlai_3B(self):
    """Check that the 3B checkpoint reproduces a fixed reference reply to "Sam".

    Loads ``facebook/blenderbot-3B`` (converted from PyTorch weights via
    ``from_pt=True``) plus its tokenizer, generates a reply to the one-word
    prompt ``"Sam"`` and compares the decoded, stripped text with a reference
    string.
    """
    # num_beams=1 -> greedy search; min/max_length bound the reply length.
    FASTER_GEN_KWARGS = {"num_beams": 1, "early_stopping": True, "min_length": 15, "max_length": 25}
    TOK_DECODE_KW = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}

    model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B", from_pt=True)
    tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")

    src_text = ["Sam"]
    model_inputs = tokenizer(src_text, return_tensors="jax")

    generation_output = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
    # Flax ``generate`` returns an output object rather than a bare array; the
    # generated token ids live in its ``sequences`` field, which is what
    # ``batch_decode`` needs (iterating the output object itself yields keys).
    generated_utterances = generation_output.sequences
    tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'

    generated_txt = tokenizer.batch_decode(generated_utterances, **TOK_DECODE_KW)
    # assertEqual (rather than a bare assert) reports both strings on failure.
    self.assertEqual(generated_txt[0].strip(), tgt_text)
def test_lm_uneven_forward(self):
    """Forward a tiny Blenderbot LM whose encoder and decoder inputs differ in length.

    Encoder rows have length 7 and decoder rows have length 5. The logits must
    have shape ``(batch, decoder_length, vocab_size)``.
    """
    tiny_config = BlenderbotConfig(
        vocab_size=self.vocab_size,
        d_model=14,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=8,
        decoder_ffn_dim=8,
        max_position_embeddings=48,
    )
    model = FlaxBlenderbotForConditionalGeneration(tiny_config)

    # Two rows each; encoder length 7 vs decoder length 5 ("uneven").
    encoder_ids = np.array(
        [
            [71, 82, 18, 33, 46, 91, 2],
            [68, 34, 26, 58, 30, 2, 1],
        ],
        dtype=np.int64,
    )
    decoder_ids = np.array(
        [
            [82, 71, 82, 18, 2],
            [58, 68, 2, 1, 1],
        ],
        dtype=np.int64,
    )

    logits = model(input_ids=encoder_ids, decoder_input_ids=decoder_ids)["logits"]

    # The output length follows the decoder input, not the encoder input.
    n_rows, decoder_len = decoder_ids.shape
    self.assertEqual(logits.shape, (n_rows, decoder_len, tiny_config.vocab_size))
def test_lm_forward(self):
    """Run a forward pass with only ``input_ids`` and check the logits shape.

    No ``decoder_input_ids`` are passed. The logits are expected to be
    ``(batch_size, input_length, vocab_size)``, where ``input_length`` is the
    second dimension of ``input_ids``.
    """
    cfg, token_ids, n_rows = self._get_config_and_data()
    model = FlaxBlenderbotForConditionalGeneration(cfg)

    logits = model(input_ids=token_ids)["logits"]

    self.assertEqual(
        logits.shape,
        (n_rows, token_ids.shape[1], cfg.vocab_size),
    )