    def prepare_config_and_inputs(self):
        # Random ids clipped to a minimum of 3 so the special tokens
        # (bos=0, pad=1, eos=2) never appear mid-sequence.
        input_ids = np.clip(
            ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size
        )
        # Append the eos token (id 2) so every row ends with it.
        input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)

        # MBart-style shift with pad_token_id=1: the last non-pad token of each
        # row is rotated to position 0 (MBart has no fixed decoder_start_token_id).
        decoder_input_ids = shift_tokens_right(input_ids, 1)

        config = MBartConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            initializer_range=self.initializer_range,
            use_cache=False,
        )
        inputs_dict = prepare_mbart_inputs_dict(config, input_ids,
                                                decoder_input_ids)
        return config, inputs_dict
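
    # `prepare_mbart_inputs_dict` is defined elsewhere in this file; the sketch
    # below is an assumption of its usual shape in these testers (the name
    # `_prepare_mbart_inputs_dict_sketch` is hypothetical, for illustration only):
    # derive attention masks from the pad token and bundle the tensors into the
    # dict the model's forward pass expects.
    @staticmethod
    def _prepare_mbart_inputs_dict_sketch(config, input_ids, decoder_input_ids):
        attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
        decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
        return {
            "input_ids": input_ids,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
        }
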
    def test_shift_tokens_right(self):
        input_ids = np.array([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64)
        shifted = shift_tokens_right(input_ids, 1)  # pad_token_id=1
        n_pad_before = np.equal(input_ids, 1).astype(np.float32).sum()
        n_pad_after = np.equal(shifted, 1).astype(np.float32).sum()
        # Shape is preserved, and exactly one pad is consumed: the wrapped-around
        # trailing pad of the first row is overwritten by the decoder start token.
        self.assertEqual(shifted.shape, input_ids.shape)
        self.assertEqual(n_pad_after, n_pad_before - 1)
        # Every shifted row starts with its last non-pad token (here eos, id 2).
        self.assertTrue(np.equal(shifted[:, 0], 2).all())
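

# For reference, a minimal numpy sketch of the MBart-style shift the test above
# exercises (an assumption of equivalent behavior, not the library's actual
# implementation): unlike Bart, MBart rotates each row's last non-pad token
# (the language/eos id) to position 0 instead of prepending a fixed
# decoder_start_token_id, which is exactly why one trailing pad disappears and
# every shifted row starts with token 2 in the assertions above.
def _shift_tokens_right_sketch(input_ids, pad_token_id):
    # Index of the last non-pad token in each row.
    index_of_eos = (input_ids != pad_token_id).sum(axis=1) - 1
    decoder_start_tokens = input_ids[np.arange(input_ids.shape[0]), index_of_eos]
    # Rotate right by one, then overwrite position 0 with the gathered token.
    shifted = np.roll(input_ids, 1, axis=1)
    shifted[:, 0] = decoder_start_tokens
    return shifted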