def test_lm_uneven_forward(self):
    """Check the logits shape when encoder and decoder inputs differ in length.

    A small MBart configuration is instantiated and fed a batch of two
    encoder sequences (7 tokens each) together with two decoder sequences
    (5 tokens each). The returned logits are expected to have the decoder
    input's shape with the vocabulary size appended as the last dimension.
    """
    # Encoder-side token ids: 2 rows x 7 columns.
    encoder_ids = np.array(
        [
            [71, 82, 18, 33, 46, 91, 2],
            [68, 34, 26, 58, 30, 2, 1],
        ],
        dtype=np.int64,
    )
    # Decoder-side token ids: 2 rows x 5 columns (deliberately shorter).
    decoder_ids = np.array(
        [
            [82, 71, 82, 18, 2],
            [58, 68, 2, 1, 1],
        ],
        dtype=np.int64,
    )

    # Deliberately tiny model so the forward pass stays cheap.
    small_config = MBartConfig(
        vocab_size=self.vocab_size,
        d_model=14,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=8,
        decoder_ffn_dim=8,
        max_position_embeddings=48,
    )
    seq2seq = FlaxMBartForConditionalGeneration(small_config)

    result = seq2seq(input_ids=encoder_ids, decoder_input_ids=decoder_ids)

    # Logits follow the decoder input's shape, plus a vocabulary axis.
    want_shape = decoder_ids.shape + (small_config.vocab_size,)
    self.assertEqual(result["logits"].shape, want_shape)
def model(self):
    """Return the model loaded from ``self.model_name``.

    The checkpoint is loaded through
    ``FlaxMBartForConditionalGeneration.from_pretrained`` with
    ``from_pt=True``.
    """
    # NOTE(review): from_pt=True presumably converts a PyTorch checkpoint
    # into Flax parameters -- confirm against the transformers docs.
    return FlaxMBartForConditionalGeneration.from_pretrained(
        self.model_name,
        from_pt=True,
    )
def test_lm_forward(self):
    """Check the logits shape of a forward pass given only ``input_ids``.

    The config, input ids and batch size come from
    ``self._get_config_and_data()``. The logits are expected to have shape
    ``(batch_size, input_ids.shape[1], config.vocab_size)``.
    """
    cfg, token_ids, n_rows = self._get_config_and_data()
    seq2seq = FlaxMBartForConditionalGeneration(cfg)

    result = seq2seq(input_ids=token_ids)
    logits = result["logits"]

    want_shape = (n_rows, token_ids.shape[1], cfg.vocab_size)
    self.assertEqual(logits.shape, want_shape)