def create_and_check_causal_lm_decoder(
    self,
    config,
    input_ids,
    decoder_input_ids,
    attention_mask,
    decoder_attention_mask,
    lm_labels,
):
    """Run ``ProphetNetForCausalLM`` on the decoder-side inputs and check its output.

    The model is built from ``config``, moved to ``torch_device`` and switched
    to eval mode. It is then called with ``decoder_input_ids`` /
    ``decoder_attention_mask`` as its ``input_ids`` / ``attention_mask``,
    together with ``lm_labels`` as ``labels``.

    ``input_ids`` and ``attention_mask`` are accepted but not used here.
    NOTE(review): presumably kept so the signature matches sibling
    ``create_and_check_*`` helpers; confirm against the caller.

    Asserts, through ``self.parent.assertEqual``, that:
      * the model output has exactly 4 entries,
      * ``logits`` has size ``(batch_size, decoder_seq_length, vocab_size)``,
      * ``loss`` has size ``()``.
    """
    causal_lm = ProphetNetForCausalLM(config=config).to(torch_device).eval()
    result = causal_lm(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        labels=lm_labels,
    )

    check = self.parent.assertEqual

    check(len(result), 4)

    expected_logits_shape = (self.batch_size, self.decoder_seq_length, self.vocab_size)
    check(result["logits"].size(), expected_logits_shape)

    # An empty size tuple means the loss has been reduced to a single scalar.
    check(result["loss"].size(), ())
def get_encoder_decoder_model(self, config, decoder_config):
    """Build and return an ``(encoder, decoder)`` model pair.

    Args:
        config: configuration used to construct the ``BertModel`` encoder.
        decoder_config: configuration used to construct the
            ``ProphetNetForCausalLM`` decoder.

    Returns:
        A 2-tuple ``(BertModel(config), ProphetNetForCausalLM(decoder_config))``.
        The encoder is constructed first.
    """
    return BertModel(config), ProphetNetForCausalLM(decoder_config)