def test_lm_uneven_forward(self):
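    # Deliberately tiny config: enough layers/heads to exercise the full
    # encoder-decoder stack while keeping the test fast.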
    config = MBartConfig(
        vocab_size=self.vocab_size,
        d_model=14,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=8,
        decoder_ffn_dim=8,
        max_position_embeddings=48,
    )
    lm_model = FlaxMBartForConditionalGeneration(config)
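    # Encoder input (7 tokens per row) is longer than the decoder input
    # (5 tokens per row); the logits must follow the decoder shape.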
    context = np.array([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=np.int64)
    summary = np.array([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=np.int64)
    outputs = lm_model(input_ids=context, decoder_input_ids=summary)
    expected_shape = (*summary.shape, config.vocab_size)
    self.assertEqual(outputs["logits"].shape, expected_shape)

def model(self):
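    # Load the reference checkpoint; from_pt=True converts the PyTorch
    # weights to Flax parameters on the fly.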
    return FlaxMBartForConditionalGeneration.from_pretrained(self.model_name, from_pt=True)

def test_lm_forward(self):
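    # No decoder_input_ids are passed, so the model falls back to a shifted
    # copy of input_ids and the logits share the input's sequence length.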
    config, input_ids, batch_size = self._get_config_and_data()
    lm_model = FlaxMBartForConditionalGeneration(config)
    outputs = lm_model(input_ids=input_ids)
    expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
    self.assertEqual(outputs["logits"].shape, expected_shape)