示例#1
0
 def create_and_check_causal_lm_decoder(
     self,
     config,
     input_ids,
     decoder_input_ids,
     attention_mask,
     decoder_attention_mask,
     lm_labels,
 ):
     """Run a ProphetNetForCausalLM on the decoder-side inputs and check its outputs.

     A fresh model is built from ``config``, moved to ``torch_device`` and put
     in eval mode. It is called with the *decoder* tensors as its inputs.

     Args:
         config: Configuration used to construct ``ProphetNetForCausalLM``.
         input_ids: Not used by this check (kept for signature compatibility).
         decoder_input_ids: Passed to the model as ``input_ids``.
         attention_mask: Not used by this check (kept for signature compatibility).
         decoder_attention_mask: Passed to the model as ``attention_mask``.
         lm_labels: Passed to the model as ``labels`` so a loss is computed.

     Asserts (via ``self.parent``):
         * the output has exactly 4 entries;
         * ``logits`` has size ``(batch_size, decoder_seq_length, vocab_size)``;
         * ``loss`` is a scalar (size ``()``).
     """
     causal_lm = ProphetNetForCausalLM(config=config)
     causal_lm = causal_lm.to(torch_device)
     causal_lm.eval()

     # The causal-LM head is driven purely by the decoder-side tensors.
     result = causal_lm(
         input_ids=decoder_input_ids,
         attention_mask=decoder_attention_mask,
         labels=lm_labels,
     )

     checker = self.parent
     checker.assertEqual(len(result), 4)
     expected_logits_shape = (self.batch_size, self.decoder_seq_length, self.vocab_size)
     checker.assertEqual(result["logits"].size(), expected_logits_shape)
     checker.assertEqual(result["loss"].size(), ())
 def get_encoder_decoder_model(self, config, decoder_config):
     """Build and return an ``(encoder, decoder)`` model pair.

     Args:
         config: Configuration used to construct the ``BertModel`` encoder.
         decoder_config: Configuration used to construct the
             ``ProphetNetForCausalLM`` decoder.

     Returns:
         A 2-tuple ``(BertModel(config), ProphetNetForCausalLM(decoder_config))``.
     """
     # Tuple elements are evaluated left to right, so the encoder is still
     # constructed before the decoder.
     return BertModel(config), ProphetNetForCausalLM(decoder_config)